From d14b69624d5d0be8983a934eaa90ebfd56e1e261 Mon Sep 17 00:00:00 2001
From: camillebrianceau
Date: Mon, 22 Jul 2024 10:17:40 +0200
Subject: [PATCH] tests

---
 clinicadl/caps_dataset/caps_dataset_config.py |   3 +
 .../commandline/pipelines/predict/cli.py      |   4 +-
 clinicadl/predict/predict_manager.py          | 127 +++++++++---------
 clinicadl/trainer/config/reconstruction.py    |   6 +
 4 files changed, 77 insertions(+), 63 deletions(-)

diff --git a/clinicadl/caps_dataset/caps_dataset_config.py b/clinicadl/caps_dataset/caps_dataset_config.py
index 31187670e..c468a7ed8 100644
--- a/clinicadl/caps_dataset/caps_dataset_config.py
+++ b/clinicadl/caps_dataset/caps_dataset_config.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
 
+import pandas as pd
 from pydantic import BaseModel, ConfigDict
 
 from clinicadl.caps_dataset.data_config import DataConfig
@@ -136,6 +137,8 @@ def from_data_group(
         config.data.caps_directory = caps_directory
         config.data.data_tsv = data_tsv
 
+        config.data.data_df = pd.read_csv(config.data.data_tsv, sep="\t")
+
         return config
 
     @classmethod
diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py
index 848208494..cd3641d0c 100644
--- a/clinicadl/commandline/pipelines/predict/cli.py
+++ b/clinicadl/commandline/pipelines/predict/cli.py
@@ -61,8 +61,10 @@ def cli(input_maps_directory, data_group, **kwargs):
     INPUT_MAPS_DIRECTORY is the MAPS folder from where the model used for prediction will be loaded.
     DATA_GROUP is the name of the subjects and sessions list used for the interpretation.
     """
-    predictor = Predictor.from_maps(input_maps_directory)
+    print(kwargs["gpu"])
+    predictor = Predictor.from_maps(input_maps_directory, **kwargs)
     print(predictor)
+
     caps_config = CapsDatasetConfig.from_data_group(
         input_maps_directory, data_group, **kwargs
     )
diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py
index c0d665409..3a28c9fd5 100644
--- a/clinicadl/predict/predict_manager.py
+++ b/clinicadl/predict/predict_manager.py
@@ -55,7 +55,9 @@ def __init__(self, _config: TrainConfig) -> None:
         self.maps_manager = MapsManager(maps_path, parameters, verbose=None)
 
     @classmethod
-    def from_json(cls, config_file: Union[str, Path], maps_path: Union[str, Path]):
+    def from_json(
+        cls, config_file: Union[str, Path], maps_path: Union[str, Path], **kwargs
+    ):
         """
         Creates a Trainer from a json configuration file.
 
@@ -82,13 +84,14 @@ def from_json(cls, config_file: Union[str, Path], maps_path: Union[str, Path]):
             raise FileNotFoundError(f"No file found at {str(config_file)}.")
         config_dict = patch_to_read_json(read_json(config_file))  # TODO : remove patch
         config_dict["maps_dir"] = maps_path
+        config_dict.update(kwargs)
         config_object = create_training_config(config_dict["network_task"])(
             **config_dict
         )
         return cls(config_object)
 
     @classmethod
-    def from_maps(cls, maps_path: Union[str, Path]):
+    def from_maps(cls, maps_path: Union[str, Path], **kwargs):
         """
         Creates a Trainer from a json configuration file.
 
@@ -114,7 +117,7 @@ def from_maps(cls, maps_path: Union[str, Path]):
                 f"MAPS was not found at {str(maps_path)}."
                 f"To initiate a new MAPS please give a train_dict."
             )
-        return cls.from_json(maps_path / "maps.json", maps_path)
+        return cls.from_json(maps_path / "maps.json", maps_path, **kwargs)
 
     def predict(
         self,
@@ -253,19 +256,22 @@ def predict(
         )
         self._check_data_group(data_group, caps_config)
+        print(caps_config)
 
         criterion = self.maps_manager.task_manager.get_criterion(self.maps_manager.loss)
-        self._check_data_group(df=group_df)
+        # self._check_data_group(df=group_df)
 
-        assert self.config.split  # don't know if needed ? try to raise an exception ?
+        assert (
+            self.config.cross_validation.split
+        )  # don't know if needed ? try to raise an exception ?
         # assert self.config.data.label
 
-        for split in self.config.split:
+        for split in self.config.cross_validation.split:
             logger.info(f"Prediction of split {split}")
-            group_df, group_parameters = self.get_group_info(
-                self.config.data.data_group, split
-            )
+            group_df, group_parameters = self.get_group_info(data_group, split)
             # Find label code if not given
-            if self.config.is_given_label_code(self.maps_manager.label, label_code):
+            if self.config.data.is_given_label_code(
+                self.maps_manager.label, label_code
+            ):
                 self.maps_manager.task_manager.generate_label_code(
                     group_df, self.config.data.label
                 )
@@ -281,12 +287,12 @@ def predict(
                         self.maps_manager.maps_path
                         / f"{self.maps_manager.split_name}-{split}"
                         / f"best-{selection}"
-                        / self.config.data.data_group
+                        / data_group
                     )
-                    tsv_pattern = f"{self.config.data.data_group}*.tsv"
+                    tsv_pattern = f"{data_group}*.tsv"
                     for tsv_file in tsv_dir.glob(tsv_pattern):
                         tsv_file.unlink()
-            self.config.check_label(self.maps_manager.label)
+            self.config.data.check_label(self.maps_manager.label)
             if self.maps_manager.multi_network:
                 self._predict_multi(
                     group_parameters,
@@ -309,11 +315,11 @@ def predict(
             )
             if cluster.master:
                 self.maps_manager._ensemble_prediction(
-                    self.config.data.data_group,
+                    data_group,
                     split,
                     self.config.validation.selection_metrics,
                     self.config.data.use_labels,
-                    self.config.skip_leak_check,
+                    self.config.validation.skip_leak_check,
                 )
 
     def _predict_multi(
@@ -954,60 +960,57 @@ def _check_data_group(
         group_dir = self.maps_manager.maps_path / "groups" / data_group
         logger.debug(f"Group path {group_dir}")
         if group_dir.is_dir():  # Data group already exists
-            if self.config.maps_manager.overwrite:
-                if data_group in ["train", "validation"]:
-                    raise MAPSError("Cannot overwrite train or validation data group.")
-                else:
-                    print(self.config.cross_validation.split)
-                    if not self.config.cross_validation.split:
-                        self.config.cross_validation.split = (
-                            self.maps_manager.find_splits()
-                        )
-
-                    print(self.config.cross_validation.split)
-                    # assert self.config.split
-                    for split in self.config.cross_validation.split:
-                        selection_metrics = self.maps_manager._find_selection_metrics(
-                            split
-                        )
-                        for selection in selection_metrics:
-                            results_path = (
-                                self.maps_manager.maps_path
-                                / f"{self.maps_manager.split_name}-{split}"
-                                / f"best-{selection}"
-                                / data_group
-                            )
-                            if results_path.is_dir():
-                                shutil.rmtree(results_path)
-        elif df is not None or (
-            caps_config.caps_directory is not None
-            and self.config.caps_directory != Path("")
-        ):
-            raise ClinicaDLArgumentError(
-                f"Data group {data_group} is already defined. "
-                f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. "
-                f"To erase {data_group} please set overwrite to True."
-            )
-
-        elif not group_dir.is_dir() and (
-            self.config.caps_directory is None or df is None
-        ):  # Data group does not exist yet / was overwritten + missing data
-            raise ClinicaDLArgumentError(
-                f"The data group {self.config.data.data_group} does not already exist. "
-                f"Please specify a caps_directory and a tsv_path to create this data group."
-            )
+            # if self.config.maps_manager.overwrite:
+            # if data_group in ["train", "validation"]:
+            # raise MAPSError("Cannot overwrite train or validation data group.")
+            # else:
+            print("cross validation_split", self.config.cross_validation.split)
+            if not self.config.cross_validation.split:
+                self.config.cross_validation.split = self.maps_manager.find_splits()
+            print("cross validation_split", self.config.cross_validation.split)
+            # assert self.config.split
+
+            # IF there is a dir
+            for split in self.config.cross_validation.split:
+                selection_metrics = self.maps_manager._find_selection_metrics(split)
+                for selection in selection_metrics:
+                    results_path = (
+                        self.maps_manager.maps_path
+                        / f"{self.maps_manager.split_name}-{split}"
+                        / f"best-{selection}"
+                        / data_group
+                    )
+                    if results_path.is_dir() and self.config.maps_manager.overwrite:
+                        shutil.rmtree(results_path)
+        # elif df is not None or (
+        #     caps_config.caps_directory is not None
+        #     and self.config.caps_directory != Path("")
+        # ):
+        #     raise ClinicaDLArgumentError(
+        #         f"Data group {data_group} is already defined. "
+        #         f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. "
+        #         f"To erase {data_group} please set overwrite to True."
+        #     )
+
+        # elif not group_dir.is_dir() and (
+        #     self.config.caps_directory is None or df is None
+        # ):  # Data group does not exist yet / was overwritten + missing data
+        #     raise ClinicaDLArgumentError(
+        #         f"The data group {self.config.data.data_group} does not already exist. "
+        #         f"Please specify a caps_directory and a tsv_path to create this data group."
+        #     )
 
         elif (
             not group_dir.is_dir()
         ):  # Data group does not exist yet / was overwritten + all data is provided
-            if self.config.skip_leak_check:
+            if self.config.validation.skip_leak_check:
                 logger.info("Skipping data leakage check")
             else:
-                self._check_leakage(self.config.data.data_group, df)
+                self._check_leakage(data_group, caps_config.data.data_df)
             self._write_data_group(
-                self.config.data.data_group,
+                data_group,
                 df,
-                self.config.caps_directory,
-                self.config.multi_cohort,
+                caps_config.data.caps_directory,
+                caps_config.data.multi_cohort,
                 label=self.config.data.label,
             )
diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py
index 4ad9d5927..48e5a0f7a 100644
--- a/clinicadl/trainer/config/reconstruction.py
+++ b/clinicadl/trainer/config/reconstruction.py
@@ -33,6 +33,12 @@ class NetworkConfig(BaseNetworkConfig):  # TODO : put in model module
     def validator_architecture(cls, v):
         return v  # TODO : connect to network module to have list of available architectures
 
+    @field_validator("normalization", mode="before")
+    def validator_normalization(cls, v):
+        if v == "batch":
+            v = "BatchNorm"
+        return v
+
 
 class ValidationConfig(BaseValidationConfig):
     """Config class for the validation procedure in reconstruction mode."""
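
Note on the reconstruction.py hunk above: validator_normalization is a pydantic field validator run in "before" mode, so the legacy value "batch" is rewritten to "BatchNorm" before the field itself is validated. A minimal standalone sketch of the same pattern, assuming pydantic v2 (NormalizationConfig is a hypothetical stand-in, not clinicadl's actual NetworkConfig):

    from pydantic import BaseModel, field_validator


    class NormalizationConfig(BaseModel):
        # Hypothetical stand-in for clinicadl's reconstruction NetworkConfig.
        normalization: str = "BatchNorm"

        @field_validator("normalization", mode="before")
        @classmethod
        def validator_normalization(cls, v):
            # Accept the legacy CLI value "batch" and map it to the expected name.
            if v == "batch":
                v = "BatchNorm"
            return v


    print(NormalizationConfig(normalization="batch").normalization)  # -> BatchNorm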