From 8a1589e59260cb77db0da57e64c0499bc32c7aea Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:39:56 +0200 Subject: [PATCH] =?UTF-8?q?Sortir=20les=20=C3=A9tapes=20de=20validation=20?= =?UTF-8?q?du=20MapsManager=20(#657)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first try to create a validator * test to integrate validator config --- clinicadl/API_test.py | 17 + clinicadl/maps_manager/maps_manager.py | 273 -------------- clinicadl/predict/predict_manager.py | 27 +- clinicadl/trainer/tasks_utils.py | 169 --------- clinicadl/trainer/trainer.py | 54 ++- clinicadl/validator/config.py | 48 +++ clinicadl/validator/validator.py | 498 +++++++++++++++++++++++++ 7 files changed, 614 insertions(+), 472 deletions(-) create mode 100644 clinicadl/API_test.py create mode 100644 clinicadl/validator/config.py create mode 100644 clinicadl/validator/validator.py diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py new file mode 100644 index 000000000..a7240c92d --- /dev/null +++ b/clinicadl/API_test.py @@ -0,0 +1,17 @@ +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.trainer.config.classification import ClassificationConfig +from clinicadl.trainer.trainer import Trainer +from clinicadl.utils.enum import ExtractionMethod, Preprocessing, Task +from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options + +image_config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=ExtractionMethod.IMAGE, + preprocessing_type=Preprocessing.T1_LINEAR, +) + +DeepLearningPrepareData(image_config) + +config = ClassificationConfig() +trainer = Trainer(config) +trainer.train(split_list=config.cross_validation.split, overwrite=True) diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index 73b6430eb..3b32486b5 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -7,8 +7,6 @@ import pandas as pd import torch -import torch.distributed as dist -from torch.amp import autocast from clinicadl.caps_dataset.caps_dataset_utils import read_json from clinicadl.caps_dataset.data import ( @@ -17,7 +15,6 @@ from clinicadl.metrics.metric_module import MetricModule from clinicadl.metrics.utils import ( check_selection_metric, - find_selection_metrics, ) from clinicadl.predict.utils import get_prediction from clinicadl.trainer.tasks_utils import ( @@ -25,8 +22,6 @@ evaluation_metrics, generate_label_code, output_size, - test, - test_da, ) from clinicadl.transforms.config import TransformsConfig from clinicadl.utils import cluster @@ -149,274 +144,6 @@ def __getattr__(self, name): ################################### # High-level functions templates # ################################### - def _test_loader( - self, - dataloader, - criterion, - data_group: str, - split: int, - selection_metrics, - use_labels=True, - gpu=None, - amp=False, - network=None, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - if cluster.master: - log_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection_metric}" - / data_group - ) - self.write_description_log( - log_dir, - data_group, - dataloader.dataset.config.data.caps_dict, - dataloader.dataset.config.data.data_df, - ) - - # load the best trained model during the training - model, _ = self._init_model( - transfer_path=self.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp) - - prediction_df, metrics = test( - mode=self.mode, - metrics_module=self.metrics_module, - n_classes=self.n_classes, - network_task=self.network_task, - model=model, - dataloader=dataloader, - criterion=criterion, - use_labels=use_labels, - amp=amp, - report_ci=report_ci, - ) - if use_labels: - if network is not None: - metrics[f"{self.mode}_id"] = network - - loss_to_log = ( - metrics["Metric_values"][-1] if report_ci else metrics["loss"] - ) - - logger.info( - f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - if cluster.master: - # Replace here - self._mode_level_to_tsv( - prediction_df, - metrics, - split, - selection_metric, - data_group=data_group, - ) - - def _test_loader_ssda( - self, - dataloader, - criterion, - alpha, - data_group, - split, - selection_metrics, - use_labels=True, - gpu=None, - network=None, - target=False, - report_ci=True, - ): - """ - Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. - - Args: - dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. - criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. - data_group (str): name of the data group used for the testing task. - split (int): Index of the split used to train the model tested. - selection_metrics (list[str]): List of metrics used to select the best models which are tested. - use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - log_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection_metric}" - / data_group - ) - self.write_description_log( - log_dir, - data_group, - dataloader.dataset.caps_dict, - dataloader.dataset.df, - ) - - # load the best trained model during the training - model, _ = self._init_model( - transfer_path=self.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - ) - prediction_df, metrics = test_da( - self.network_task, - model, - dataloader, - criterion, - target=target, - report_ci=report_ci, - ) - if use_labels: - if network is not None: - metrics[f"{self.mode}_id"] = network - - if report_ci: - loss_to_log = metrics["Metric_values"][-1] - else: - loss_to_log = metrics["loss"] - - logger.info( - f"{self.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" - ) - - # Replace here - self._mode_level_to_tsv( - prediction_df, metrics, split, selection_metric, data_group=data_group - ) - - @torch.no_grad() - def _compute_output_tensors( - self, - dataset, - data_group, - split, - selection_metrics, - nb_images=None, - gpu=None, - network=None, - ): - """ - Compute the output tensors and saves them in the MAPS. - - Args: - dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. - data_group (str): name of the data group used for the task. - split (int): split number. - selection_metrics (list[str]): metrics used for model selection. - nb_images (int): number of full images to write. Default computes the outputs of the whole data set. - gpu (bool): If given, a new value for the device of the model will be computed. - network (int): Index of the network tested (only used in multi-network setting). - """ - for selection_metric in selection_metrics: - # load the best trained model during the training - model, _ = self._init_model( - transfer_path=self.maps_path, - split=split, - transfer_selection=selection_metric, - gpu=gpu, - network=network, - nb_unfrozen_layer=self.nb_unfrozen_layer, - ) - model = DDP(model, fsdp=self.fully_sharded_data_parallel, amp=self.amp) - model.eval() - - tensor_path = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection_metric}" - / data_group - / "tensors" - ) - if cluster.master: - tensor_path.mkdir(parents=True, exist_ok=True) - dist.barrier() - - if nb_images is None: # Compute outputs for the whole data set - nb_modes = len(dataset) - else: - nb_modes = nb_images * dataset.elem_per_image - - for i in [ - *range(cluster.rank, nb_modes, cluster.world_size), - *range(int(nb_modes % cluster.world_size <= cluster.rank)), - ]: - data = dataset[i] - image = data["image"] - x = image.unsqueeze(0).to(model.device) - with autocast("cuda", enabled=self.std_amp): - output = model(x) - output = output.squeeze(0).cpu().float() - participant_id = data["participant_id"] - session_id = data["session_id"] - mode_id = data[f"{self.mode}_id"] - input_filename = ( - f"{participant_id}_{session_id}_{self.mode}-{mode_id}_input.pt" - ) - output_filename = ( - f"{participant_id}_{session_id}_{self.mode}-{mode_id}_output.pt" - ) - torch.save(image, tensor_path / input_filename) - torch.save(output, tensor_path / output_filename) - logger.debug(f"File saved at {[input_filename, output_filename]}") - - def _ensemble_prediction( - self, - data_group, - split, - selection_metrics, - use_labels=True, - skip_leak_check=False, - ): - """Computes the results on the image-level.""" - - if not selection_metrics: - selection_metrics = find_selection_metrics( - self.maps_path, self.split_name, split - ) - - for selection_metric in selection_metrics: - ##################### - # Soft voting - if self.num_networks > 1 and not skip_leak_check: - self._ensemble_to_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) - elif self.mode != "image" and not skip_leak_check: - self._mode_to_image_tsv( - split, - selection=selection_metric, - data_group=data_group, - use_labels=use_labels, - ) ############################### # Checks # diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index 879ef0e54..c197a96de 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -29,6 +29,7 @@ ClinicaDLDataLeakageError, MAPSError, ) +from clinicadl.validator.validator import Validator logger = getLogger("clinicadl.predict_manager") level_list: List[str] = ["warning", "info", "debug"] @@ -38,6 +39,7 @@ class PredictManager: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: self.maps_manager = MapsManager(_config.maps_dir) self._config = _config + self.validator = Validator() def predict( self, @@ -183,7 +185,8 @@ def predict( split_selection_metrics, ) if cluster.master: - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, self._config.data_group, split, self._config.selection_metrics, @@ -288,12 +291,13 @@ def _predict_multi( if self._config.n_proc is not None else self.maps_manager.n_proc, ) - self.maps_manager._test_loader( - test_loader, - criterion, - self._config.data_group, - split, - split_selection_metrics, + self.validator._test_loader( + maps_manager=self.maps_manager, + dataloader=test_loader, + criterion=criterion, + data_group=self._config.data_group, + split=split, + selection_metrics=split_selection_metrics, use_labels=self._config.use_labels, gpu=self._config.gpu, amp=self._config.amp, @@ -301,7 +305,8 @@ def _predict_multi( ) if self._config.save_tensor: logger.debug("Saving tensors") - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, data_test, self._config.data_group, split, @@ -416,7 +421,8 @@ def _predict_single( if self._config.n_proc is not None else self.maps_manager.n_proc, ) - self.maps_manager._test_loader( + self.validator._test_loader( + self.maps_manager, test_loader, criterion, self._config.data_group, @@ -428,7 +434,8 @@ def _predict_single( ) if self._config.save_tensor: logger.debug("Saving tensors") - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, data_test, self._config.data_group, split, diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py index 93a652aa8..dc28d0acd 100644 --- a/clinicadl/trainer/tasks_utils.py +++ b/clinicadl/trainer/tasks_utils.py @@ -207,175 +207,6 @@ def evaluation_metrics(network_task: Union[str, Task]): raise ValueError("Unknown network task") -def test( - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task, - model: Network, - dataloader: DataLoader, - criterion: _Loss, - use_labels: bool = True, - amp: bool = False, - report_ci=False, -) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Parameters - ---------- - model: Network - The model trained. - dataloader: DataLoader - Wrapper of a CapsDataset. - criterion: _Loss - Function to calculate the loss. - use_labels: bool - If True the true_label will be written in output DataFrame - and metrics dict will be created. - amp: bool - If True, enables Pytorch's automatic mixed precision. - - Returns - ------- - the results and metrics on the image level. - """ - model.eval() - dataloader.dataset.eval() - - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = {} - with torch.no_grad(): - for i, data in enumerate(dataloader): - # initialize the loss list to save the loss components - with autocast("cuda", enabled=amp): - outputs, loss_dict = model(data, criterion, use_labels=use_labels) - - if i == 0: - for loss_component in loss_dict.keys(): - total_loss[loss_component] = 0 - for loss_component in total_loss.keys(): - total_loss[loss_component] += loss_dict[loss_component].float() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, - mode, - metrics_module, - n_classes, - idx, - data, - outputs.float(), - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - dataframes = [None] * dist.get_world_size() - dist.gather_object(results_df, dataframes if dist.get_rank() == 0 else None, dst=0) - if dist.get_rank() == 0: - results_df = pd.concat(dataframes) - del dataframes - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - for loss_component in total_loss.keys(): - dist.reduce(total_loss[loss_component], dst=0) - loss_value = total_loss[loss_component].item() / cluster.world_size - - if report_ci: - metrics_dict["Metric_names"].append(loss_component) - metrics_dict["Metric_values"].append(loss_value) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict[loss_component] = loss_value - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - -def test_da( - mode: str, - metrics_module: MetricModule, - n_classes: int, - network_task: Union[str, Task], - model: Network, - dataloader: DataLoader, - criterion: _Loss, - alpha: float = 0, - use_labels: bool = True, - target: bool = True, - report_ci=False, -) -> Tuple[pd.DataFrame, Dict[str, float]]: - """ - Computes the predictions and evaluation metrics. - - Args: - model: the model trained. - dataloader: wrapper of a CapsDataset. - criterion: function to calculate the loss. - use_labels: If True the true_label will be written in output DataFrame - and metrics dict will be created. - Returns: - the results and metrics on the image level. - """ - model.eval() - dataloader.dataset.eval() - results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) - total_loss = 0 - with torch.no_grad(): - for i, data in enumerate(dataloader): - outputs, loss_dict = model.compute_outputs_and_loss_test( - data, criterion, alpha, target - ) - total_loss += loss_dict["loss"].item() - - # Generate detailed DataFrame - for idx in range(len(data["participant_id"])): - row = generate_test_row( - network_task, mode, metrics_module, n_classes, idx, data, outputs - ) - row_df = pd.DataFrame( - row, columns=columns(network_task, mode, n_classes) - ) - results_df = pd.concat([results_df, row_df]) - - del outputs, loss_dict - results_df.reset_index(inplace=True, drop=True) - - if not use_labels: - metrics_dict = None - else: - metrics_dict = compute_metrics( - network_task, results_df, metrics_module, report_ci=report_ci - ) - if report_ci: - metrics_dict["Metric_names"].append("loss") - metrics_dict["Metric_values"].append(total_loss) - metrics_dict["Lower_CI"].append("N/A") - metrics_dict["Upper_CI"].append("N/A") - metrics_dict["SE"].append("N/A") - - else: - metrics_dict["loss"] = total_loss - - torch.cuda.empty_cache() - - return results_df, metrics_dict - - def columns(network_task: Union[str, Task], mode: str, n_classes: Optional[int] = None): """ List of the columns' names in the TSV file containing the predictions. diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 16d2d88d6..3c279d155 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -33,6 +33,7 @@ patch_to_read_json, ) from clinicadl.trainer.tasks_utils import create_training_config +from clinicadl.validator.validator import Validator if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback @@ -43,8 +44,6 @@ generate_sampler, get_criterion, save_outputs, - test, - test_da, ) logger = getLogger("clinicadl.trainer") @@ -64,6 +63,7 @@ def __init__( """ self.config = config self.maps_manager = self._init_maps_manager(config) + self.validator = Validator() self._check_args() def _init_maps_manager(self, config) -> MapsManager: @@ -371,12 +371,14 @@ def _train_single( ) if cluster.master: - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "train", split, self.config.validation.selection_metrics, ) - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "validation", split, self.config.validation.selection_metrics, @@ -495,12 +497,14 @@ def _train_multi( resume = False if cluster.master: - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "train", split, self.config.validation.selection_metrics, ) - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "validation", split, self.config.validation.selection_metrics, @@ -706,12 +710,14 @@ def _train_ssda( resume=resume, ) - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "train", split, self.config.validation.selection_metrics, ) - self.maps_manager._ensemble_prediction( + self.validator._ensemble_prediction( + self.maps_manager, "validation", split, self.config.validation.selection_metrics, @@ -861,7 +867,7 @@ def _train( ): evaluation_flag = False - _, metrics_train = test( + _, metrics_train = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -871,7 +877,7 @@ def _train( criterion=criterion, amp=self.maps_manager.std_amp, ) - _, metrics_valid = test( + _, metrics_valid = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -928,7 +934,7 @@ def _train( model.zero_grad(set_to_none=True) logger.debug(f"Last checkpoint at the end of the epoch {epoch}") - _, metrics_train = test( + _, metrics_train = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -938,7 +944,7 @@ def _train( criterion=criterion, amp=self.maps_manager.std_amp, ) - _, metrics_valid = test( + _, metrics_valid = self.validator.test( mode=self.maps_manager.mode, metrics_module=self.maps_manager.metrics_module, n_classes=self.maps_manager.n_classes, @@ -998,7 +1004,8 @@ def _train( epoch += 1 del model - self.maps_manager._test_loader( + self.validator._test_loader( + self.maps_manager, train_loader, criterion, "train", @@ -1007,7 +1014,8 @@ def _train( amp=self.maps_manager.std_amp, network=network, ) - self.maps_manager._test_loader( + self.validator._test_loader( + self.maps_manager, valid_loader, criterion, "validation", @@ -1018,7 +1026,8 @@ def _train( ) if save_outputs(self.maps_manager.network_task): - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, train_loader.dataset, "train", split, @@ -1026,7 +1035,8 @@ def _train( nb_images=1, network=network, ) - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, valid_loader.dataset, "validation", split, @@ -1400,7 +1410,8 @@ def _train_ssdann( epoch += 1 - self.maps_manager._test_loader_ssda( + self.validator._test_loader_ssda( + self.maps_manager, train_target_loader, criterion, data_group="train", @@ -1410,7 +1421,8 @@ def _train_ssdann( target=True, alpha=0, ) - self.maps_manager._test_loader_ssda( + self.validator._test_loader_ssda( + self.maps_manager, valid_loader, criterion, data_group="validation", @@ -1422,7 +1434,8 @@ def _train_ssdann( ) if save_outputs(self.maps_manager.network_task): - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, train_target_loader.dataset, "train", split, @@ -1430,7 +1443,8 @@ def _train_ssdann( nb_images=1, network=network, ) - self.maps_manager._compute_output_tensors( + self.validator._compute_output_tensors( + self.maps_manager, train_target_loader.dataset, "validation", split, diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py new file mode 100644 index 000000000..165b36dd0 --- /dev/null +++ b/clinicadl/validator/config.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from pydantic import ( + BaseModel, + ConfigDict, + computed_field, + field_validator, +) + +from clinicadl.utils.factories import DefaultFromLibrary + + +class ValidatorConfig(BaseModel): + """Base config class to configure the validator.""" + + maps_path: Path + mode: str + network_task: str + split_name: Optional[str] = None + num_networks: Optional[int] = None + fsdp: Optional[bool] = None + amp: Optional[bool] = None + metrics_module: Optional = None + n_classes: Optional[int] = None + nb_unfrozen_layers: Optional[int] = None + std_amp: Optional[bool] = None + + # pydantic config + model_config = ConfigDict( + validate_assignment=True, + use_enum_values=True, + validate_default=True, + ) + + @computed_field + @property + @abstractmethod + def metric(self) -> str: + """The name of the metric.""" + + @field_validator("get_not_nans", mode="after") + @classmethod + def validator_get_not_nans(cls, v): + assert not v, "get_not_nans not supported in ClinicaDL. Please set to False." + + return v diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py new file mode 100644 index 000000000..d55810299 --- /dev/null +++ b/clinicadl/validator/validator.py @@ -0,0 +1,498 @@ +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +from torch.amp import autocast +from torch.nn.modules.loss import _Loss +from torch.utils.data import DataLoader + +from clinicadl.maps_manager.maps_manager import MapsManager +from clinicadl.metrics.metric_module import MetricModule +from clinicadl.metrics.utils import find_selection_metrics +from clinicadl.network.network import Network +from clinicadl.trainer.tasks_utils import columns, compute_metrics, generate_test_row +from clinicadl.utils import cluster +from clinicadl.utils.computational.ddp import DDP, init_ddp +from clinicadl.utils.enum import ( + ClassificationLoss, + ClassificationMetric, + ReconstructionLoss, + ReconstructionMetric, + RegressionLoss, + RegressionMetric, + Task, +) +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLConfigurationError, + MAPSError, +) + +logger = getLogger("clinicadl.maps_manager") +level_list: List[str] = ["warning", "info", "debug"] + + +# TODO save weights on CPU for better compatibility + + +class Validator: + def test( + self, + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task, + model: Network, + dataloader: DataLoader, + criterion: _Loss, + use_labels: bool = True, + amp: bool = False, + report_ci=False, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Parameters + ---------- + model: Network + The model trained. + dataloader: DataLoader + Wrapper of a CapsDataset. + criterion: _Loss + Function to calculate the loss. + use_labels: bool + If True the true_label will be written in output DataFrame + and metrics dict will be created. + amp: bool + If True, enables Pytorch's automatic mixed precision. + + Returns + ------- + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = {} + with torch.no_grad(): + for i, data in enumerate(dataloader): + # initialize the loss list to save the loss components + with autocast("cuda", enabled=amp): + outputs, loss_dict = model(data, criterion, use_labels=use_labels) + + if i == 0: + for loss_component in loss_dict.keys(): + total_loss[loss_component] = 0 + for loss_component in total_loss.keys(): + total_loss[loss_component] += loss_dict[loss_component].float() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs.float(), + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + dataframes = [None] * dist.get_world_size() + dist.gather_object( + results_df, dataframes if dist.get_rank() == 0 else None, dst=0 + ) + if dist.get_rank() == 0: + results_df = pd.concat(dataframes) + del dataframes + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + for loss_component in total_loss.keys(): + dist.reduce(total_loss[loss_component], dst=0) + loss_value = total_loss[loss_component].item() / cluster.world_size + + if report_ci: + metrics_dict["Metric_names"].append(loss_component) + metrics_dict["Metric_values"].append(loss_value) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict[loss_component] = loss_value + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def test_da( + self, + mode: str, + metrics_module: MetricModule, + n_classes: int, + network_task: Union[str, Task], + model: Network, + dataloader: DataLoader, + criterion: _Loss, + alpha: float = 0, + use_labels: bool = True, + target: bool = True, + report_ci=False, + ) -> Tuple[pd.DataFrame, Dict[str, float]]: + """ + Computes the predictions and evaluation metrics. + + Args: + model: the model trained. + dataloader: wrapper of a CapsDataset. + criterion: function to calculate the loss. + use_labels: If True the true_label will be written in output DataFrame + and metrics dict will be created. + Returns: + the results and metrics on the image level. + """ + model.eval() + dataloader.dataset.eval() + results_df = pd.DataFrame(columns=columns(network_task, mode, n_classes)) + total_loss = 0 + with torch.no_grad(): + for i, data in enumerate(dataloader): + outputs, loss_dict = model.compute_outputs_and_loss_test( + data, criterion, alpha, target + ) + total_loss += loss_dict["loss"].item() + + # Generate detailed DataFrame + for idx in range(len(data["participant_id"])): + row = generate_test_row( + network_task, + mode, + metrics_module, + n_classes, + idx, + data, + outputs, + ) + row_df = pd.DataFrame( + row, columns=columns(network_task, mode, n_classes) + ) + results_df = pd.concat([results_df, row_df]) + + del outputs, loss_dict + results_df.reset_index(inplace=True, drop=True) + + if not use_labels: + metrics_dict = None + else: + metrics_dict = compute_metrics( + network_task, results_df, metrics_module, report_ci=report_ci + ) + if report_ci: + metrics_dict["Metric_names"].append("loss") + metrics_dict["Metric_values"].append(total_loss) + metrics_dict["Lower_CI"].append("N/A") + metrics_dict["Upper_CI"].append("N/A") + metrics_dict["SE"].append("N/A") + + else: + metrics_dict["loss"] = total_loss + + torch.cuda.empty_cache() + + return results_df, metrics_dict + + def _test_loader( + self, + maps_manager: MapsManager, + dataloader, + criterion, + data_group: str, + split: int, + selection_metrics, + use_labels=True, + gpu=None, + amp=False, + network=None, + report_ci=True, + ): + """ + Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. + + Args: + dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. + criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. + data_group (str): name of the data group used for the testing task. + split (int): Index of the split used to train the model tested. + selection_metrics (list[str]): List of metrics used to select the best models which are tested. + use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. + gpu (bool): If given, a new value for the device of the model will be computed. + amp (bool): If enabled, uses Automatic Mixed Precision (requires GPU usage). + network (int): Index of the network tested (only used in multi-network setting). + """ + for selection_metric in selection_metrics: + if cluster.master: + log_dir = ( + maps_manager.maps_path + / f"{maps_manager.split_name}-{split}" + / f"best-{selection_metric}" + / data_group + ) + maps_manager.write_description_log( + log_dir, + data_group, + dataloader.dataset.config.data.caps_dict, + dataloader.dataset.config.data.data_df, + ) + + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + + prediction_df, metrics = self.test( + mode=maps_manager.mode, + metrics_module=maps_manager.metrics_module, + n_classes=maps_manager.n_classes, + network_task=maps_manager.network_task, + model=model, + dataloader=dataloader, + criterion=criterion, + use_labels=use_labels, + amp=amp, + report_ci=report_ci, + ) + if use_labels: + if network is not None: + metrics[f"{maps_manager.mode}_id"] = network + + loss_to_log = ( + metrics["Metric_values"][-1] if report_ci else metrics["loss"] + ) + + logger.info( + f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" + ) + + if cluster.master: + # Replace here + maps_manager._mode_level_to_tsv( + prediction_df, + metrics, + split, + selection_metric, + data_group=data_group, + ) + + def _test_loader_ssda( + self, + maps_manager: MapsManager, + dataloader, + criterion, + alpha, + data_group, + split, + selection_metrics, + use_labels=True, + gpu=None, + network=None, + target=False, + report_ci=True, + ): + """ + Launches the testing task on a dataset wrapped by a DataLoader and writes prediction TSV files. + + Args: + dataloader (torch.utils.data.DataLoader): DataLoader wrapping the test CapsDataset. + criterion (torch.nn.modules.loss._Loss): optimization criterion used during training. + data_group (str): name of the data group used for the testing task. + split (int): Index of the split used to train the model tested. + selection_metrics (list[str]): List of metrics used to select the best models which are tested. + use_labels (bool): If True, the labels must exist in test meta-data and metrics are computed. + gpu (bool): If given, a new value for the device of the model will be computed. + network (int): Index of the network tested (only used in multi-network setting). + """ + for selection_metric in selection_metrics: + log_dir = ( + maps_manager.maps_path + / f"{maps_manager.split_name}-{split}" + / f"best-{selection_metric}" + / data_group + ) + maps_manager.write_description_log( + log_dir, + data_group, + dataloader.dataset.caps_dict, + dataloader.dataset.df, + ) + + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + ) + prediction_df, metrics = self.test_da( + network_task=maps_manager.network_task, + model=model, + dataloader=dataloader, + criterion=criterion, + target=target, + report_ci=report_ci, + mode=maps_manager.mode, + metrics_module=maps_manager.metrics_module, + n_classes=maps_manager.n_classes, + ) + if use_labels: + if network is not None: + metrics[f"{maps_manager.mode}_id"] = network + + if report_ci: + loss_to_log = metrics["Metric_values"][-1] + else: + loss_to_log = metrics["loss"] + + logger.info( + f"{maps_manager.mode} level {data_group} loss is {loss_to_log} for model selected on {selection_metric}" + ) + + # Replace here + maps_manager._mode_level_to_tsv( + prediction_df, metrics, split, selection_metric, data_group=data_group + ) + + @torch.no_grad() + def _compute_output_tensors( + self, + maps_manager: MapsManager, + dataset, + data_group, + split, + selection_metrics, + nb_images=None, + gpu=None, + network=None, + ): + """ + Compute the output tensors and saves them in the MAPS. + + Args: + dataset (clinicadl.caps_dataset.data.CapsDataset): wrapper of the data set. + data_group (str): name of the data group used for the task. + split (int): split number. + selection_metrics (list[str]): metrics used for model selection. + nb_images (int): number of full images to write. Default computes the outputs of the whole data set. + gpu (bool): If given, a new value for the device of the model will be computed. + network (int): Index of the network tested (only used in multi-network setting). + """ + for selection_metric in selection_metrics: + # load the best trained model during the training + model, _ = maps_manager._init_model( + transfer_path=maps_manager.maps_path, + split=split, + transfer_selection=selection_metric, + gpu=gpu, + network=network, + nb_unfrozen_layer=maps_manager.nb_unfrozen_layer, + ) + model = DDP( + model, + fsdp=maps_manager.fully_sharded_data_parallel, + amp=maps_manager.amp, + ) + model.eval() + + tensor_path = ( + maps_manager.maps_path + / f"{maps_manager.split_name}-{split}" + / f"best-{selection_metric}" + / data_group + / "tensors" + ) + if cluster.master: + tensor_path.mkdir(parents=True, exist_ok=True) + dist.barrier() + + if nb_images is None: # Compute outputs for the whole data set + nb_modes = len(dataset) + else: + nb_modes = nb_images * dataset.elem_per_image + + for i in [ + *range(cluster.rank, nb_modes, cluster.world_size), + *range(int(nb_modes % cluster.world_size <= cluster.rank)), + ]: + data = dataset[i] + image = data["image"] + x = image.unsqueeze(0).to(model.device) + with autocast("cuda", enabled=maps_manager.std_amp): + output = model(x) + output = output.squeeze(0).cpu().float() + participant_id = data["participant_id"] + session_id = data["session_id"] + mode_id = data[f"{maps_manager.mode}_id"] + input_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_input.pt" + output_filename = f"{participant_id}_{session_id}_{maps_manager.mode}-{mode_id}_output.pt" + torch.save(image, tensor_path / input_filename) + torch.save(output, tensor_path / output_filename) + logger.debug(f"File saved at {[input_filename, output_filename]}") + + def _ensemble_prediction( + self, + maps_manager: MapsManager, + data_group, + split, + selection_metrics, + use_labels=True, + skip_leak_check=False, + ): + """Computes the results on the image-level.""" + + if not selection_metrics: + selection_metrics = find_selection_metrics( + maps_manager.maps_path, maps_manager.split_name, split + ) + + for selection_metric in selection_metrics: + ##################### + # Soft voting + if maps_manager.num_networks > 1 and not skip_leak_check: + maps_manager._ensemble_to_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + ) + elif maps_manager.mode != "image" and not skip_leak_check: + maps_manager._mode_to_image_tsv( + split, + selection=selection_metric, + data_group=data_group, + use_labels=use_labels, + )