Skip to content

Commit

Permalink
886 add MetricsSaver handler to save metrics and details into files (#…
Browse files Browse the repository at this point in the history
…1497)

* [DLMED] add IterationHandler refer to the EpochHandler in ignite

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix the multi-gpu issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix typo

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix distributed tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add engine to metrics

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] share metric details in engine

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add metrics report

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add average value to report

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add summary report

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add docs

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] add unit tests and distributed tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix typo

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] remove from min_tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] remove useless var

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add skip flag

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add dist tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] enhance some unit tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] remove from min_tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] change to standlone APIs to write files

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] add file type check

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add output_type arg

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] develop standlone API

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 error

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix min test

Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>
  • Loading branch information
Nic-Ma and monai-bot authored Jan 29, 2021
1 parent 56c88f6 commit b3d063c
Show file tree
Hide file tree
Showing 20 changed files with 682 additions and 33 deletions.
7 changes: 7 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ Model checkpoint saver
.. autoclass:: CheckpointSaver
:members:


Metrics saver
-------------
.. autoclass:: MetricsSaver
:members:


CSV saver
---------
.. autoclass:: ClassificationSaver
Expand Down
1 change: 1 addition & 0 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def set_sampler_epoch(engine: Engine):
output=None,
batch=None,
metrics={},
metric_details={},
dataloader=None,
device=device,
key_metric_name=None, # we can set many metrics, only use key_metric to compare and save the best model
Expand Down
3 changes: 2 additions & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
from .metric_logger import MetricLogger
from .metrics_saver import MetricsSaver
from .roc_auc import ROCAUC
from .segmentation_saver import SegmentationSaver
from .smartcache_handler import SmartCacheHandler
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
from .utils import all_gather, stopping_fn_from_loss, stopping_fn_from_metric
from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .validation_handler import ValidationHandler
10 changes: 9 additions & 1 deletion monai/handlers/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
metric_name: str = "hit_rate",
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
save_details: bool = True,
) -> None:
"""
Expand All @@ -44,6 +45,8 @@ def __init__(
and you can also input those names instead.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.
save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
See also:
:py:meth:`monai.metrics.confusion_matrix`
Expand All @@ -55,7 +58,12 @@ def __init__(
reduction=MetricReduction.NONE,
)
self.metric_name = metric_name
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
super().__init__(
metric_fn=metric_fn,
output_transform=output_transform,
device=device,
save_details=save_details,
)

def _reduce(self, scores) -> Any:
confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN)
Expand Down
10 changes: 9 additions & 1 deletion monai/handlers/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
directed: bool = False,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
save_details: bool = True,
) -> None:
"""
Expand All @@ -45,6 +46,8 @@ def __init__(
directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.
save_details: whether to save metric computation details per image, for example: hausdorff distance
of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key.
"""
super().__init__(output_transform, device=device)
Expand All @@ -55,4 +58,9 @@ def __init__(
directed=directed,
reduction=MetricReduction.NONE,
)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
super().__init__(
metric_fn=metric_fn,
output_transform=output_transform,
device=device,
save_details=save_details,
)
46 changes: 36 additions & 10 deletions monai/handlers/iteration_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence

import torch

from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import do_metric_reduction
from monai.utils import MetricReduction, exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.4.2", exact_version, "NotComputableError")
idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
Metric, _ = optional_import("ignite.metrics", "0.4.2", exact_version, "Metric")
reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.2", exact_version, "reinit__is_reduced")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import
Expand All @@ -33,6 +37,8 @@ class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to option
expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.
save_details: whether to save metric computation details per image, for example: mean_dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
"""

Expand All @@ -41,10 +47,14 @@ def __init__(
metric_fn: Callable,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
save_details: bool = True,
) -> None:
self._is_reduced: bool = False
self.metric_fn = metric_fn
self.save_details = save_details
self._scores: List = []
self._engine: Optional[Engine] = None
self._name: Optional[str] = None
super().__init__(output_transform, device=device)

@reinit__is_reduced
Expand Down Expand Up @@ -79,17 +89,16 @@ def compute(self) -> Any:

ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# make sure the _scores is evenly-divisible on multi-GPUs
length = _scores.shape[0]
max_len = max(idist.all_gather(length)).item()
if length < max_len:
size = [max_len - length] + list(_scores.shape[1:])
_scores = torch.cat([_scores, _scores.new_full(size, float("NaN"))], dim=0)

# all gather across all processes
_scores = idist.all_gather(_scores)
_scores = evenly_divisible_all_gather(data=_scores)
self._is_reduced = True

# save score of every image into engine.state for other components
if self.save_details:
if self._engine is None or self._name is None:
raise RuntimeError("plesae call the attach() function to connect expected engine first.")
self._engine.state.metric_details[self._name] = _scores

result: torch.Tensor = torch.zeros(1)
if idist.get_rank() == 0:
# run compute_fn on zero rank only
Expand All @@ -103,3 +112,20 @@ def compute(self) -> Any:

def _reduce(self, scores) -> Any:
return do_metric_reduction(scores, MetricReduction.MEAN)[0]

def attach(self, engine: Engine, name: str) -> None:
"""
Attaches current metric to provided engine. On the end of engine's run,
`engine.state.metrics` dictionary will contain computed metric's value under provided name.
Args:
engine: the engine to which the metric must be attached.
name: the name of the metric to attach.
"""
super().attach(engine=engine, name=name)
# FIXME: record engine for communication, ignite will support it in the future version soon
self._engine = engine
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {}
10 changes: 9 additions & 1 deletion monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
include_background: bool = True,
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
save_details: bool = True,
) -> None:
"""
Expand All @@ -36,6 +37,8 @@ def __init__(
Defaults to True.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.
save_details: whether to save metric computation details per image, for example: mean dice of every image.
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
See also:
:py:meth:`monai.metrics.meandice.compute_meandice`
Expand All @@ -44,4 +47,9 @@ def __init__(
include_background=include_background,
reduction=MetricReduction.NONE,
)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
super().__init__(
metric_fn=metric_fn,
output_transform=output_transform,
device=device,
save_details=save_details,
)
137 changes: 137 additions & 0 deletions monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from monai.handlers.utils import write_metrics_reports
from monai.utils import ensure_tuple, exact_version, optional_import
from monai.utils.module import get_torch_version_tuple

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine")


class MetricsSaver:
"""
ignite handler to save metrics values and details into expected files.
Args:
save_dir: directory to save the metrics and metric details.
metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
None - don't save any metrics into files.
"*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
list of strings - specify the expected metrics to save.
default to "*" to save all the metrics into `metrics.csv`.
metric_details: expected metric details to save into files, for example: mean dice
of every channel of every image in the validation dataset.
the data in `engine.state.metric_details` must contain at least 2 dims: (batch, classes, ...),
if not, will unsequeeze to 2 dims.
this arg can be: None, "*" or list of strings.
None - don't save any metrics into files.
"*" - save all the existing metrics in `engine.state.metric_details` dict into separate files.
list of strings - specify the expected metrics to save.
if not None, every metric will save a separate `{metric name}_raw.csv` file.
batch_transform: callable function to extract the meta_dict from input batch data if saving metric details.
used to extract filenames from input dict data.
summary_ops: expected computation operations to generate the summary report.
it can be: None, "*" or list of strings.
None - don't generate summary report for every expected metric_details
"*" - generate summary report for every metric_details with all the supported operations.
list of strings - generate summary report for every metric_details with specified operations, they
should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
default to None.
save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
delimiter: the delimiter charactor in CSV file, default to "\t".
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""

def __init__(
self,
save_dir: str,
metrics: Optional[Union[str, Sequence[str]]] = "*",
metric_details: Optional[Union[str, Sequence[str]]] = None,
batch_transform: Callable = lambda x: x,
summary_ops: Optional[Union[str, Sequence[str]]] = None,
save_rank: int = 0,
delimiter: str = "\t",
output_type: str = "csv",
) -> None:
self.save_dir = save_dir
self.metrics = ensure_tuple(metrics) if metrics is not None else None
self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None
self.batch_transform = batch_transform
self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None
self.save_rank = save_rank
self.deli = delimiter
self.output_type = output_type
self._filenames: List[str] = []

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.STARTED, self._started)
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)
engine.add_event_handler(Events.EPOCH_COMPLETED, self)

def _started(self, engine: Engine) -> None:
self._filenames = []

def _get_filenames(self, engine: Engine) -> None:
if self.metric_details is not None:
_filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"]))
self._filenames += _filenames

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
ws = idist.get_world_size()
if self.save_rank >= ws:
raise ValueError("target rank is greater than the distributed group size.")

_images = self._filenames
if ws > 1:
_filenames = self.deli.join(_images)
if get_torch_version_tuple() > (1, 6, 0):
# all gather across all processes
_filenames = self.deli.join(idist.all_gather(_filenames))
else:
raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")
_images = _filenames.split(self.deli)

# only save metrics to file in specified rank
if idist.get_rank() == self.save_rank:
_metrics = {}
if self.metrics is not None and len(engine.state.metrics) > 0:
_metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics}
_metric_details = {}
if self.metric_details is not None and len(engine.state.metric_details) > 0:
for k, v in engine.state.metric_details.items():
if k in self.metric_details or "*" in self.metric_details:
_metric_details[k] = v

write_metrics_reports(
save_dir=self.save_dir,
images=_images,
metrics=_metrics,
metric_details=_metric_details,
summary_ops=self.summary_ops,
deli=self.deli,
output_type=self.output_type,
)
10 changes: 9 additions & 1 deletion monai/handlers/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
distance_metric: str = "euclidean",
output_transform: Callable = lambda x: x,
device: Optional[torch.device] = None,
save_details: bool = True,
) -> None:
"""
Expand All @@ -42,6 +43,8 @@ def __init__(
the metric used to compute surface distance. Defaults to ``"euclidean"``.
output_transform: transform the ignite.engine.state.output into [y_pred, y] pair.
device: device specification in case of distributed computation usage.
save_details: whether to save metric computation details per image, for example: surface dice
of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key.
"""
metric_fn = SurfaceDistanceMetric(
Expand All @@ -50,4 +53,9 @@ def __init__(
distance_metric=distance_metric,
reduction=MetricReduction.NONE,
)
super().__init__(metric_fn=metric_fn, output_transform=output_transform, device=device)
super().__init__(
metric_fn=metric_fn,
output_transform=output_transform,
device=device,
save_details=save_details,
)
Loading

0 comments on commit b3d063c

Please sign in to comment.