diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index c2d2c9196..3b32486b5 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -12,15 +12,12 @@ from clinicadl.caps_dataset.data import ( return_dataset, ) -from clinicadl.maps_manager.tmp_config import TmpConfig from clinicadl.metrics.metric_module import MetricModule from clinicadl.metrics.utils import ( check_selection_metric, ) from clinicadl.predict.utils import get_prediction -from clinicadl.splitter.split_utils import init_split_manager from clinicadl.trainer.tasks_utils import ( - create_training_config, ensemble_prediction, evaluation_metrics, generate_label_code, @@ -34,7 +31,6 @@ ClinicaDLConfigurationError, MAPSError, ) -from clinicadl.utils.iotools.data_utils import load_data_test from clinicadl.utils.iotools.maps_manager_utils import ( add_default_values, ) @@ -74,89 +70,194 @@ def __init__( f"MAPS was not found at {maps_path}." f"To initiate a new MAPS please give a train_dict." ) + test_parameters = self.get_parameters() + # test_parameters = path_decoder(test_parameters) + # from clinicadl.trainer.task_manager import TaskConfig - test_parameters = self.read_maps_json() - config = TmpConfig(**test_parameters) - self.init_existing_maps(config) + self.parameters = add_default_values(test_parameters) - # Initiate MAPS - else: - print(parameters) - config = TmpConfig(**parameters) - self.init_new_maps(config) - - self.config = config - init_ddp(gpu=self.config.gpu, logger=logger) - - def init_existing_maps(self, config: TmpConfig): - config.n_classes = config.output_size - if config.network_task == "classification": - if config.n_classes is None: - config.n_classes = output_size( - config.network_task, None, config.data_df, config.label - ) - config.metrics_module = MetricModule( - evaluation_metrics(config.network_task), n_classes=config.n_classes - ) + ## to initialize the task parameters - elif ( - config.network_task == "regression" - or config.network_task == "reconstruction" - ): - config.metrics_module = MetricModule( - evaluation_metrics(config.network_task), n_classes=None - ) + # self.task_manager = self._init_task_manager() - else: - raise NotImplementedError( - f"Task {config.network_task} is not implemented in ClinicaDL. " - f"Please choose between classification, regression and reconstruction." - ) + self.n_classes = self.output_size + if self.network_task == "classification": + if self.n_classes is None: + self.n_classes = output_size( + self.network_task, None, self.df, self.label + ) + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=self.n_classes + ) - config.split_name = ( - self._check_split_wording() - ) # Used only for retro-compatibility - - def init_new_maps(self, config: TmpConfig): - config.check_args() - config.split_name = "split" # Used only for retro-compatibility - if cluster.master: - if ( - config.maps_path.is_dir() and config.maps_path.is_file() - ) or ( # Non-folder file - config.maps_path.is_dir() - and list(config.maps_path.iterdir()) # Non empty folder + elif ( + self.network_task == "regression" + or self.network_task == "reconstruction" ): - raise MAPSError( - f"You are trying to create a new MAPS at {config.maps_path} but " - f"this already corresponds to a file or a non-empty folder. \n" - f"Please remove it or choose another location." + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=None + ) + + else: + raise NotImplementedError( + f"Task {self.network_task} is not implemented in ClinicaDL. " + f"Please choose between classification, regression and reconstruction." ) - (config.maps_path / "groups").mkdir(parents=True) - logger.info(f"A new MAPS was created at {config.maps_path}") - self.write_parameters(config.maps_path, dict(config.model_dump())) - self._write_requirements_version() - self._write_training_data() - self._write_train_val_groups() - self._write_information() + self.split_name = ( + self._check_split_wording() + ) # Used only for retro-compatibility + + # Initiate MAPS + else: + self._check_args(parameters) + parameters["tsv_path"] = Path(parameters["tsv_path"]) + + self.split_name = "split" # Used only for retro-compatibility + if cluster.master: + if (maps_path.is_dir() and maps_path.is_file()) or ( # Non-folder file + maps_path.is_dir() and list(maps_path.iterdir()) # Non empty folder + ): + raise MAPSError( + f"You are trying to create a new MAPS at {maps_path} but " + f"this already corresponds to a file or a non-empty folder. \n" + f"Please remove it or choose another location." + ) + (maps_path / "groups").mkdir(parents=True) + + logger.info(f"A new MAPS was created at {maps_path}") + self.write_parameters(self.maps_path, self.parameters) + self._write_requirements_version() + self._write_training_data() + self._write_train_val_groups() + self._write_information() + + init_ddp(gpu=self.parameters["gpu"], logger=logger) def __getattr__(self, name): """Allow to directly get the values in parameters attribute""" - if name in self.config.model_dump(): - return self.config.model_dump()[name] + if name in self.parameters: + return self.parameters[name] else: raise AttributeError(f"'MapsManager' object has no attribute '{name}'") + ################################### + # High-level functions templates # + ################################### + ############################### # Checks # ############################### + def _check_args(self, parameters): + """ + Check the training parameters integrity + """ + logger.debug("Checking arguments...") + mandatory_arguments = [ + "caps_directory", + "tsv_path", + "preprocessing_dict", + "mode", + "network_task", + ] + for arg in mandatory_arguments: + if arg not in parameters: + raise ClinicaDLArgumentError( + f"The values of mandatory arguments {mandatory_arguments} should be set. " + f"No value was given for {arg}." + ) + self.parameters = add_default_values(parameters) + + transfo_config = TransformsConfig( + normalize=self.normalize, + size_reduction=self.size_reduction, + size_reduction_factor=self.size_reduction_factor, + ) + + split_manager = self._init_split_manager(None) + train_df = split_manager[0]["train"] + if "label" not in self.parameters: + self.parameters["label"] = None + + from clinicadl.trainer.tasks_utils import ( + get_default_network, + ) + from clinicadl.utils.enum import Task + + self.network_task = Task(self.parameters["network_task"]) + # self.task_config = TaskConfig(self.network_task, self.mode, df=train_df) + # self.task_manager = self._init_task_manager(df=train_df) + if self.network_task == "classification": + self.n_classes = output_size(self.network_task, None, train_df, self.label) + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=self.n_classes + ) + + elif self.network_task == "regression" or self.network_task == "reconstruction": + self.n_classes = None + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=None + ) + + else: + raise NotImplementedError( + f"Task {self.network_task} is not implemented in ClinicaDL. " + f"Please choose between classification, regression and reconstruction." + ) + if self.parameters["architecture"] == "default": + self.parameters["architecture"] = get_default_network(self.network_task) + if "selection_threshold" not in self.parameters: + self.parameters["selection_threshold"] = None + if ( + "label_code" not in self.parameters + or len(self.parameters["label_code"]) == 0 + or self.parameters["label_code"] is None + ): # Allows to set custom label code in TOML + self.parameters["label_code"] = generate_label_code( + self.network_task, train_df, self.label + ) + + full_dataset = return_dataset( + self.caps_directory, + train_df, + self.preprocessing_dict, + multi_cohort=self.multi_cohort, + label=self.label, + label_code=self.parameters["label_code"], + transforms_config=transfo_config, + ) + self.parameters.update( + { + "num_networks": full_dataset.elem_per_image, + "output_size": output_size( + self.network_task, full_dataset.size, full_dataset.df, self.label + ), + "input_size": full_dataset.size, + } + ) + + if self.parameters["num_networks"] < 2 and self.multi_network: + raise ClinicaDLConfigurationError( + f"Invalid training configuration: cannot train a multi-network " + f"framework with only {self.parameters['num_networks']} element " + f"per image." + ) + possible_selection_metrics_set = set(evaluation_metrics(self.network_task)) | { + "loss" + } + if not set(self.parameters["selection_metrics"]).issubset( + possible_selection_metrics_set + ): + raise ClinicaDLConfigurationError( + f"Selection metrics {self.parameters['selection_metrics']} " + f"must be a subset of metrics used for evaluation " + f"{possible_selection_metrics_set}." + ) - # TODO: To put elsewhere def _check_split_wording(self): """Finds if MAPS structure uses 'fold-X' or 'split-X' folders.""" - if len(list(self.config.maps_path.glob("fold-*"))) > 0: + if len(list(self.maps_path.glob("fold-*"))) > 0: return "fold" else: return "split" @@ -187,7 +288,7 @@ def _write_requirements_version(self): env_variables = subprocess.check_output("pip freeze", shell=True).decode( "utf-8" ) - with (self.config.maps_path / "environment.txt").open(mode="w") as file: + with (self.maps_path / "environment.txt").open(mode="w") as file: file.write(env_variables) except subprocess.CalledProcessError: logger.warning( @@ -197,52 +298,49 @@ def _write_requirements_version(self): def _write_training_data(self): """Writes the TSV file containing the participant and session IDs used for training.""" logger.debug("Writing training data...") + from clinicadl.utils.iotools.data_utils import load_data_test train_df = load_data_test( - self.config.tsv_path, - self.config.diagnoses, + self.tsv_path, + self.diagnoses, baseline=False, - multi_cohort=self.config.multi_cohort, + multi_cohort=self.multi_cohort, ) train_df = train_df[["participant_id", "session_id"]] - if self.config.transfer_path: - transfer_train_path = ( - self.config.transfer_path / "groups" / "train+validation.tsv" - ) + if self.transfer_path: + transfer_train_path = self.transfer_path / "groups" / "train+validation.tsv" transfer_train_df = pd.read_csv(transfer_train_path, sep="\t") transfer_train_df = transfer_train_df[["participant_id", "session_id"]] train_df = pd.concat([train_df, transfer_train_df]) train_df.drop_duplicates(inplace=True) train_df.to_csv( - self.config.maps_path / "groups" / "train+validation.tsv", - sep="\t", - index=False, + self.maps_path / "groups" / "train+validation.tsv", sep="\t", index=False ) def _write_train_val_groups(self): """Defines the training and validation groups at the initialization""" logger.debug("Writing training and validation groups...") - split_manager = init_split_manager() + split_manager = self._init_split_manager() for split in split_manager.split_iterator(): for data_group in ["train", "validation"]: df = split_manager[split][data_group] group_path = ( - self.config.maps_path + self.maps_path / "groups" / data_group - / f"{self.config.split_name}-{split}" + / f"{self.split_name}-{split}" ) group_path.mkdir(parents=True, exist_ok=True) columns = ["participant_id", "session_id", "cohort"] - if self.config.label is not None: - columns.append(self.config.label) + if self.label is not None: + columns.append(self.label) df.to_csv(group_path / "data.tsv", sep="\t", columns=columns) self.write_parameters( group_path, { - "caps_directory": self.config.caps_directory, - "multi_cohort": self.config.multi_cohort, + "caps_directory": self.caps_directory, + "multi_cohort": self.multi_cohort, }, verbose=False, ) @@ -255,7 +353,7 @@ def _write_information(self): import clinicadl.network as network_package - model_class = getattr(network_package, self.config.architecture) + model_class = getattr(network_package, self.architecture) args = list( model_class.__init__.__code__.co_varnames[ : model_class.__init__.__code__.co_argcount @@ -264,16 +362,16 @@ def _write_information(self): args.remove("self") kwargs = dict() for arg in args: - kwargs[arg] = self.config.model_dump()[arg] + kwargs[arg] = self.parameters[arg] kwargs["gpu"] = False model = model_class(**kwargs) file_name = "information.log" - with (self.config.maps_path / file_name).open(mode="w") as f: + with (self.maps_path / file_name).open(mode="w") as f: f.write(f"- Date :\t{datetime.now().strftime('%d %b %Y, %H:%M:%S')}\n\n") - f.write(f"- Path :\t{self.config.maps_path}\n\n") + f.write(f"- Path :\t{self.maps_path}\n\n") # f.write("- Job ID :\t{}\n".format(os.getenv('SLURM_JOBID'))) f.write(f"- Model :\t{model.layers}\n\n") @@ -322,14 +420,14 @@ def _mode_level_to_tsv( data_group: the name referring to the data group on which evaluation is performed. """ performance_dir = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection}" / data_group ) performance_dir.mkdir(parents=True, exist_ok=True) performance_path = ( - performance_dir / f"{data_group}_{self.config.mode}_level_prediction.tsv" + performance_dir / f"{data_group}_{self.mode}_level_prediction.tsv" ) if not performance_path.is_file(): results_df.to_csv(performance_path, index=False, sep="\t") @@ -338,13 +436,11 @@ def _mode_level_to_tsv( performance_path, index=False, sep="\t", mode="a", header=False ) - metrics_path = ( - performance_dir / f"{data_group}_{self.config.mode}_level_metrics.tsv" - ) + metrics_path = performance_dir / f"{data_group}_{self.mode}_level_metrics.tsv" if metrics is not None: # if data_group == "train" or data_group == "validation": # pd_metrics = pd.DataFrame(metrics, index = [0]) - # header = Trueconfig. + # header = True # else: # pd_metrics = pd.DataFrame(metrics).T # header = False @@ -383,27 +479,27 @@ def _ensemble_to_tsv( else: validation_dataset = "validation" test_df = get_prediction( - self.config.maps_path, - self.config.split_name, + self.maps_path, + self.split_name, data_group, split, selection, - self.config.mode, + self.mode, verbose=False, ) validation_df = get_prediction( - self.config.maps_path, - self.config.split_name, + self.maps_path, + self.split_name, validation_dataset, split, selection, - self.config.mode, + self.mode, verbose=False, ) performance_dir = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection}" / data_group ) @@ -411,13 +507,13 @@ def _ensemble_to_tsv( performance_dir.mkdir(parents=True, exist_ok=True) df_final, metrics = ensemble_prediction( - self.config.mode, - self.config.metrics_module, - self.config.n_classes, - self.config.network_task, + self.mode, + self.metrics_module, + self.n_classes, + self.network_task, test_df, validation_df, - selection_threshold=self.config.selection_threshold, + selection_threshold=self.selection_threshold, use_labels=use_labels, ) @@ -452,19 +548,19 @@ def _mode_to_image_tsv( """ sub_df = get_prediction( - self.config.maps_path, - self.config.split_name, + self.maps_path, + self.split_name, data_group, split, selection, - self.config.mode, + self.mode, verbose=False, ) - sub_df.rename(columns={f"{self.config.mode}_id": "image_id"}, inplace=True) + sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True) performance_dir = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection}" / data_group ) @@ -475,11 +571,11 @@ def _mode_to_image_tsv( ) if use_labels: metrics_df = pd.read_csv( - performance_dir / f"{data_group}_{self.config.mode}_level_metrics.tsv", + performance_dir / f"{data_group}_{self.mode}_level_metrics.tsv", sep="\t", ) - if f"{self.config.mode}_id" in metrics_df: - del metrics_df[f"{self.config.mode}_id"] + if f"{self.mode}_id" in metrics_df: + del metrics_df[f"{self.mode}_id"] metrics_df.to_csv( (performance_dir / f"{data_group}_image_level_metrics.tsv"), index=False, @@ -489,11 +585,9 @@ def _mode_to_image_tsv( ############################### # Objects initialization # ############################### - - # TODO: to put in ClinicaDL model ? def _init_model( self, - transfer_path: Optional[Path] = None, + transfer_path: Path = None, transfer_selection=None, nb_unfrozen_layer=0, split=None, @@ -514,9 +608,9 @@ def _init_model( """ import clinicadl.network as network_package - logger.debug(f"Initialization of model {self.config.architecture}") + logger.debug(f"Initialization of model {self.architecture}") # or choose to implement a dictionary - model_class = getattr(network_package, self.config.architecture) + model_class = getattr(network_package, self.architecture) args = list( model_class.__init__.__code__.co_varnames[ : model_class.__init__.__code__.co_argcount @@ -525,7 +619,7 @@ def _init_model( args.remove("self") kwargs = dict() for arg in args: - kwargs[arg] = self.config.model_dump()[arg] + kwargs[arg] = self.parameters[arg] # Change device from the training parameters if gpu is not None: @@ -542,8 +636,8 @@ def _init_model( if resume: checkpoint_path = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / "tmp" / "checkpoint.pth.tar" ) @@ -580,12 +674,32 @@ def _init_model( return model, current_epoch - # TODO: To put in the splitter + def _init_split_manager(self, split_list=None, ssda_bool: bool = False): + from clinicadl.validation import split_manager + + split_class = getattr(split_manager, self.validation) + args = list( + split_class.__init__.__code__.co_varnames[ + : split_class.__init__.__code__.co_argcount + ] + ) + args.remove("self") + args.remove("split_list") + kwargs = {"split_list": split_list} + for arg in args: + kwargs[arg] = self.parameters[arg] + + if ssda_bool: + kwargs["caps_directory"] = self.caps_target + kwargs["tsv_path"] = self.tsv_target_lab + + return split_class(**kwargs) + def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): # A intégrer directement dans _init_split_manager from clinicadl.validation import split_manager - split_class = getattr(split_manager, self.config.validation) + split_class = getattr(split_manager, self.validation) args = list( split_class.__init__.__code__.co_varnames[ : split_class.__init__.__code__.co_argcount @@ -595,13 +709,37 @@ def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): args.remove("split_list") kwargs = {"split_list": split_list} for arg in args: - kwargs[arg] = self.config.model_dump()[arg] + kwargs[arg] = self.parameters[arg] kwargs["caps_directory"] = Path(caps_dir) kwargs["tsv_path"] = Path(tsv_dir) return split_class(**kwargs) + # def _init_task_manager( + # self, df: Optional[pd.DataFrame] = None, n_classes: Optional[int] = None + # ): + # from clinicadl.utils.task_manager import ( + # ClassificationManager, + # ReconstructionManager, + # RegressionManager, + # ) + + # if self.network_task == "classification": + # if n_classes is not None: + # return ClassificationManager(self.mode, n_classes=n_classes) + # else: + # return ClassificationManager(self.mode, df=df, label=self.label) + # elif self.network_task == "regression": + # return RegressionManager(self.mode) + # elif self.network_task == "reconstruction": + # return ReconstructionManager(self.mode) + # else: + # raise NotImplementedError( + # f"Task {self.network_task} is not implemented in ClinicaDL. " + # f"Please choose between classification, regression and reconstruction." + # ) + ############################### # Getters # ############################### @@ -620,8 +758,8 @@ def _print_description_log( selection_metric (str): Metric used for best weights selection. """ log_dir = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection_metric}" / data_group ) @@ -629,11 +767,42 @@ def _print_description_log( with log_path.open(mode="r") as f: content = f.read() - def read_maps_json(self): - """Returns the training dictionary.""" - json_path = self.config.maps_path / "maps.json" + def get_parameters(self): + """Returns the training parameters dictionary.""" + json_path = self.maps_path / "maps.json" return read_json(json_path) + # never used ?? + # def get_model( + # self, split: int = 0, selection_metric: str = None, network: int = None + # ) -> Network: + # selection_metric = self._check_selection_metric(split, selection_metric) + # if self.multi_network: + # if network is None: + # raise ClinicaDLArgumentError( + # "Please precise the network number that must be loaded." + # ) + # return self._init_model( + # self.maps_path, + # selection_metric, + # split, + # network=network, + # nb_unfrozen_layer=self.nb_unfrozen_layer, + # )[0] + + # def get_best_epoch( + # self, split: int = 0, selection_metric: str = None, network: int = None + # ) -> int: + # selection_metric = self._check_selection_metric(split, selection_metric) + # if self.multi_network: + # if network is None: + # raise ClinicaDLArgumentError( + # "Please precise the network number that must be loaded." + # ) + # return self.get_state_dict(split=split, selection_metric=selection_metric)[ + # "epoch" + # ] + def get_state_dict( self, split=0, @@ -662,24 +831,24 @@ def get_state_dict( (Dict): dictionary of results (weights, epoch number, metrics values) """ selection_metric = check_selection_metric( - self.config.maps_path, self.config.split_name, split, selection_metric + self.maps_path, self.split_name, split, selection_metric ) - if self.config.multi_network: + if self.multi_network: if network is None: raise ClinicaDLArgumentError( "Please precise the network number that must be loaded." ) else: model_path = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection_metric}" / f"network-{network}_model.pth.tar" ) else: model_path = ( - self.config.maps_path - / f"{self.config.split_name}-{split}" + self.maps_path + / f"{self.split_name}-{split}" / f"best-{selection_metric}" / "model.pth.tar" ) @@ -698,4 +867,4 @@ def std_amp(self) -> bool: distinguishing the base DDP with AMP and the usage of FSDP with AMP which then calls the internal FSDP AMP mechanisms. """ - return self.config.amp and not self.config.fully_sharded_data_parallel + return self.amp and not self.fully_sharded_data_parallel diff --git a/clinicadl/maps_manager/new_aps_manager.py b/clinicadl/maps_manager/new_aps_manager.py deleted file mode 100644 index 116f65eb8..000000000 --- a/clinicadl/maps_manager/new_aps_manager.py +++ /dev/null @@ -1,175 +0,0 @@ -import json -import subprocess -from datetime import datetime -from logging import getLogger -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import pandas as pd -import torch -import torch.distributed as dist -from torch.amp import autocast - -from clinicadl.caps_dataset.caps_dataset_utils import read_json -from clinicadl.caps_dataset.data import ( - return_dataset, -) -from clinicadl.metrics.metric_module import MetricModule -from clinicadl.metrics.utils import ( - check_selection_metric, - find_selection_metrics, -) -from clinicadl.predict.utils import get_prediction -from clinicadl.splitter.split_utils import init_split_manager -from clinicadl.trainer.tasks_utils import ( - ensemble_prediction, - evaluation_metrics, - generate_label_code, - output_size, - test, - test_da, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils import cluster -from clinicadl.utils.computational.ddp import DDP, init_ddp -from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, - ClinicaDLConfigurationError, - MAPSError, -) -from clinicadl.utils.iotools.maps_manager_utils import ( - add_default_values, -) -from clinicadl.utils.iotools.utils import path_encoder - -logger = getLogger("clinicadl.maps_manager") - - -class InOutManager: - def __init__(self, maps_path: Path, config): - self.maps_path = maps_path - - def _write_requirements_version(self): - """Writes the environment.txt file.""" - logger.debug("Writing requirement version...") - try: - env_variables = subprocess.check_output("pip freeze", shell=True).decode( - "utf-8" - ) - with (self.maps_path / "environment.txt").open(mode="w") as file: - file.write(env_variables) - except subprocess.CalledProcessError: - logger.warning( - "You do not have the right to execute pip freeze. Your environment will not be written" - ) - - def _write_training_data( - self, - tsv_path: Path, - diagnoses: List[str], - multi_cohort: bool, - transfer_path: Optional[Path] = None, - ): - """Writes the TSV file containing the participant and session IDs used for training.""" - logger.debug("Writing training data...") - from clinicadl.utils.iotools.data_utils import load_data_test - - train_df = load_data_test( - tsv_path, # self.tsv_path, - diagnoses, # self.diagnoses, - baseline=False, - multi_cohort=multi_cohort, # self.multi_cohort, - ) - train_df = train_df[["participant_id", "session_id"]] - if transfer_path: # self.transfer_path: - transfer_train_path = transfer_path / "groups" / "train+validation.tsv" - transfer_train_df = pd.read_csv(transfer_train_path, sep="\t") - transfer_train_df = transfer_train_df[["participant_id", "session_id"]] - train_df = pd.concat([train_df, transfer_train_df]) - train_df.drop_duplicates(inplace=True) - train_df.to_csv( - self.maps_path / "groups" / "train+validation.tsv", sep="\t", index=False - ) - - def _write_train_val_groups( - self, label: str, split_name: str, caps_directory: Path, multi_cohort: bool - ): - """Defines the training and validation groups at the initialization""" - logger.debug("Writing training and validation groups...") - split_manager = init_split_manager() - for split in split_manager.split_iterator(): - for data_group in ["train", "validation"]: - df = split_manager[split][data_group] - group_path = ( - self.maps_path / "groups" / data_group / f"{split_name}-{split}" - ) - group_path.mkdir(parents=True, exist_ok=True) - - columns = ["participant_id", "session_id", "cohort"] - if label is not None: - columns.append(label) - df.to_csv(group_path / "data.tsv", sep="\t", columns=columns) - self.write_parameters( - group_path, - { - "caps_directory": caps_directory, - "multi_cohort": multi_cohort, - }, - verbose=False, - ) - - def _write_information(self, architecture, parameters): - """ - Writes model architecture of the MAPS in MAPS root. - """ - from datetime import datetime - - import clinicadl.network as network_package - - model_class = getattr(network_package, architecture) - args = list( - model_class.__init__.__code__.co_varnames[ - : model_class.__init__.__code__.co_argcount - ] - ) - args.remove("self") - kwargs = dict() - for arg in args: - kwargs[arg] = parameters[arg] - kwargs["gpu"] = False - - model = model_class(**kwargs) - - file_name = "information.log" - - with (self.maps_path / file_name).open(mode="w") as f: - f.write(f"- Date :\t{datetime.now().strftime('%d %b %Y, %H:%M:%S')}\n\n") - f.write(f"- Path :\t{self.maps_path}\n\n") - # f.write("- Job ID :\t{}\n".format(os.getenv('SLURM_JOBID'))) - f.write(f"- Model :\t{model.layers}\n\n") - - del model - - @staticmethod - def write_description_log( - log_dir: Path, - data_group, - caps_dict, - df, - ): - """ - Write description log file associated to a data group. - - Args: - log_dir (str): path to the log file directory. - data_group (str): name of the data group used for the task. - caps_dict (dict[str, str]): Dictionary of the CAPS folders used for the task - df (pd.DataFrame): DataFrame of the meta-data used for the task. - """ - log_dir.mkdir(parents=True, exist_ok=True) - log_path = log_dir / "description.log" - with log_path.open(mode="w") as f: - f.write(f"Prediction {data_group} group - {datetime.now()}\n") - f.write(f"Data loaded from CAPS directories: {caps_dict}\n") - f.write(f"Number of participants: {df.participant_id.nunique()}\n") - f.write(f"Number of sessions: {len(df)}\n") diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 651511a0d..fb04e5920 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -35,7 +35,6 @@ from clinicadl.trainer.tasks_utils import create_training_config from clinicadl.validator.validator import Validator from clinicadl.splitter.split_utils import init_split_manager -from clinicadl.maps_manager.tmp_config import TmpConfig if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback @@ -64,7 +63,6 @@ def __init__( config : TrainConfig """ self.config = config - self.tmp_config = TmpConfig(**config.model_dump()) self.maps_manager = self._init_maps_manager(config) self.validator = Validator() @@ -168,8 +166,8 @@ def resume(self, splits: List[int]) -> None: ) # TODO : check these two lines. Why do we need a split_manager? split_manager = init_split_manager( - validation=self.tmp_config.validation, - parameters=self.tmp_config.model_dump(), + validation=self.config.validation, + parameters=self.config.model_dump(), split_list=splits, ) split_iterator = split_manager.split_iterator() @@ -232,7 +230,7 @@ def train( def check_split_list(self, split_list, overwrite): existing_splits = [] split_manager = init_split_manager( - self.tmp_config.validation, self.tmp_config.model_dump(), split_list + self.config.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): split_path = ( @@ -272,7 +270,7 @@ def _resume( """ missing_splits = [] split_manager = init_split_manager( - self.tmp_config.validation, self.tmp_config.model_dump(), split_list + self.config.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): @@ -374,7 +372,7 @@ def _train_single( """ split_manager = init_split_manager( - self.tmp_config.validation, self.tmp_config.model_dump(), split_list + self.config.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): logger.info(f"Training split {split}") @@ -472,7 +470,7 @@ def _train_multi( # train_transforms, all_transforms = self.config.transforms.get_transforms() split_manager = init_split_manager( - self.tmp_config.validation, self.tmp_config.model_dump(), split_list + self.config.validation, self.config.model_dump(), split_list ) for split in split_manager.split_iterator(): logger.info(f"Training split {split}")