From bf67c1cf0c744e36cf5a775df5ea0dd2de61dae8 Mon Sep 17 00:00:00 2001
From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com>
Date: Tue, 1 Oct 2024 16:19:34 +0200
Subject: [PATCH] Clean MapsManager (#658)

* create the get_dataloader function

* remove the train multi
---
 clinicadl/API_test.py                |  59 ++-
 clinicadl/maps_manager/tmp_config.py | 513 +++++++++++++++++++++++++++
 clinicadl/splitter/split_utils.py    |  31 +-
 clinicadl/trainer/trainer.py         | 430 +++++++++++-----------
 4 files changed, 809 insertions(+), 224 deletions(-)
 create mode 100644 clinicadl/maps_manager/tmp_config.py

diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py
index a7240c92d..5f17c044c 100644
--- a/clinicadl/API_test.py
+++ b/clinicadl/API_test.py
@@ -1,4 +1,7 @@
+from pathlib import Path
+
 from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
+from clinicadl.caps_dataset.data import return_dataset
 from clinicadl.prepare_data.prepare_data import DeepLearningPrepareData
 from clinicadl.trainer.config.classification import ClassificationConfig
 from clinicadl.trainer.trainer import Trainer
@@ -12,6 +15,56 @@
 DeepLearningPrepareData(image_config)
 
 
-config = ClassificationConfig()
-trainer = Trainer(config)
-trainer.train(split_list=config.cross_validation.split, overwrite=True)
+dataset = return_dataset(
+    input_dir,
+    data_df,
+    preprocessing_dict,
+    transforms_config,
+    label,
+    label_code,
+    cnn_index,
+    label_presence,
+    multi_cohort,
+)
+
+split_config = SplitConfig()
+splitter = Splitter(split_config)
+
+validator_config = ValidatorConfig()
+validator = Validator(validator_config)
+
+train_config = ClassificationConfig()
+trainer = Trainer(train_config, validator)
+
+for split in splitter.split_iterator():
+    for network in range(
+        first_network, self.maps_manager.num_networks
+    ):  # for multi_network
+        ###### actual _train_single method of the Trainer ############
+        train_loader = trainer.get_dataloader(dataset, split, network, "train", config)
+        valid_loader = validator.get_dataloader(
+            dataset, split, network, "valid", config
+        )  # ?? validator, trainer ?
+
+        trainer._train(
+            train_loader,
+            valid_loader,
+            split=split,
+            network=network,
+            resume=resume,  # in a config class
+            callbacks=[CodeCarbonTracker],  # in a config class ?
+        )
+
+        validator._ensemble_prediction(
+            self.maps_manager,
+            "train",
+            split,
+            self.config.validation.selection_metrics,
+        )
+        validator._ensemble_prediction(
+            self.maps_manager,
+            "validation",
+            split,
+            self.config.validation.selection_metrics,
+        )
+        ###### end ############
diff --git a/clinicadl/maps_manager/tmp_config.py b/clinicadl/maps_manager/tmp_config.py
new file mode 100644
index 000000000..84e9e464c
--- /dev/null
+++ b/clinicadl/maps_manager/tmp_config.py
@@ -0,0 +1,513 @@
+from abc import ABC, abstractmethod
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import pandas as pd
+import torchvision.transforms as torch_transforms
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    NegativeInt,
+    NonNegativeFloat,
+    NonNegativeInt,
+    PositiveFloat,
+    PositiveInt,
+    computed_field,
+    field_validator,
+    model_validator,
+)
+from typing_extensions import Self
+
+from clinicadl.caps_dataset.data import return_dataset
+from clinicadl.metrics.metric_module import MetricModule
+from clinicadl.splitter.split_utils import find_splits
+from clinicadl.trainer.tasks_utils import (
+    evaluation_metrics,
+    generate_label_code,
+    get_default_network,
+    output_size,
+)
+from clinicadl.transforms import transforms
+from clinicadl.transforms.config import TransformsConfig
+from clinicadl.utils.enum import (
+    Compensation,
+    ExperimentTracking,
+    Mode,
+    Optimizer,
+    Sampler,
+    SizeReductionFactor,
+    Task,
+    Transform,
+)
+from clinicadl.utils.exceptions import (
+    ClinicaDLArgumentError,
+    ClinicaDLConfigurationError,
+    ClinicaDLTSVError,
+)
+from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test
+from clinicadl.utils.iotools.utils import read_preprocessing
+
+logger = getLogger("clinicadl.tmp")
+
+#### this class is not used in clinicadl but I just want to keep it in case
+
+
+class TmpConfig(BaseModel):
+    """
+    arguments needed : caps_directory, maps_path, loss
+    """
+
+    output_size: Optional[int] = None
+    n_classes: Optional[int] = None
+    network_task: Optional[str] = None
+    metrics_module: Optional[MetricModule] = None
+    split_name: Optional[str] = None
+    selection_threshold: Optional[int] = None
+    num_networks: Optional[int] = None
+    input_size: Optional[Sequence[int]] = None
+    validation: str = "SingleSplit"
+    std_amp: Optional[bool] = None
+    preprocessing_dict: Optional[dict] = None
+
+    emissions_calculator: bool = False
+    track_exp: Optional[ExperimentTracking] = None
+
+    amp: bool = False
+    fully_sharded_data_parallel: bool = False
+    gpu: bool = True
+
+    n_splits: NonNegativeInt = 0
+    split: Optional[Tuple[NonNegativeInt, ...]] = None
+    tsv_path: Optional[Path] = None  # not needed in predict ?
+
+    caps_directory: Path
+    baseline: bool = False
+    diagnoses: Tuple[str, ...]
= ("AD", "CN") + data_df: Optional[pd.DataFrame] = None + label: Optional[str] = None + label_code: Optional[Dict[str, int]] = None + multi_cohort: bool = False + mask_path: Optional[Path] = None + preprocessing_json: Optional[Path] = None + data_tsv: Optional[Path] = None + n_subjects: int = 300 + + batch_size: PositiveInt = 8 + n_proc: PositiveInt = 2 + sampler: Sampler = Sampler.RANDOM + + patience: NonNegativeInt = 0 + tolerance: NonNegativeFloat = 0.0 + + adaptive_learning_rate: bool = False + + maps_path: Path + data_group: Optional[str] = None + overwrite: bool = False + save_nifti: bool = False + + architecture: str = "default" + dropout: NonNegativeFloat = 0.0 + loss: str + multi_network: bool = False + + accumulation_steps: PositiveInt = 1 + epochs: PositiveInt = 20 + profiler: bool = False + + learning_rate: PositiveFloat = 1e-4 + optimizer: Optimizer = Optimizer.ADAM + weight_decay: NonNegativeFloat = 1e-4 + + compensation: Compensation = Compensation.MEMORY + deterministic: bool = False + save_all_models: bool = False + seed: int = 0 + config_file: Optional[Path] = None + + caps_target: Path = Path("") + preprocessing_json_target: Path = Path("") + ssda_network: bool = False + tsv_target_lab: Path = Path("") + tsv_target_unlab: Path = Path("") + + nb_unfrozen_layer: NonNegativeInt = 0 + transfer_path: Optional[Path] = None + transfer_selection_metric: str = "loss" + + data_augmentation: Tuple[Transform, ...] = () + train_transformations: Optional[Tuple[Transform, ...]] = None + normalize: bool = True + size_reduction: bool = False + size_reduction_factor: SizeReductionFactor = SizeReductionFactor.TWO + + evaluation_steps: NonNegativeInt = 0 + selection_metrics: Tuple[str, ...] = () + valid_longitudinal: bool = False + skip_leak_check: bool = False + + # pydantic config + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + @model_validator(mode="after") + def check_mandatory_args(self) -> Self: + if self.caps_directory is None: + raise ClinicaDLArgumentError( + "caps_directory is a mandatory argument and it's set to None" + ) + if self.tsv_path is None: + raise ClinicaDLArgumentError( + "tsv_path is a mandatory argument and it's set to None" + ) + if self.preprocessing_dict is None: + raise ClinicaDLArgumentError( + "preprocessing_dict is a mandatory argument and it's set to None" + ) + if self.mode is None: + raise ClinicaDLArgumentError( + "mode is a mandatory argument and it's set to None" + ) + if self.network_task is None: + raise ClinicaDLArgumentError( + "network_task is a mandatory argument and it's set to None" + ) + return self + + def check_args(self): + transfo_config = TransformsConfig( + normalize=self.normalize, + size_reduction=self.size_reduction, + size_reduction_factor=self.size_reduction_factor, + ) + + if self.network_task == "classification": + from clinicadl.splitter.split_utils import init_split_manager + + if self.n_splits > 1 and self.validation == "SingleSplit": + self.validation = "KFoldSplit" + + split_manager = init_split_manager( + validation=self.validation, + parameters=self.model_dump(), + split_list=None, + caps_target=self.caps_target, + tsv_target_lab=self.tsv_target_lab, + ) + train_df = split_manager[0]["train"] + self.n_classes = output_size(self.network_task, None, train_df, self.label) + self.metrics_module = MetricModule( + evaluation_metrics(self.network_task), n_classes=self.n_classes + ) + + elif self.network_task == "regression" or self.network_task == "reconstruction": + self.metrics_module = 
MetricModule( + evaluation_metrics(self.network_task), n_classes=None + ) + + else: + raise NotImplementedError( + f"Task {self.network_task} is not implemented in ClinicaDL. " + f"Please choose between classification, regression and reconstruction." + ) + + if self.architecture == "default": + self.architecture = get_default_network(self.network_task) + + if (self.label_code is None) or ( + len(self.label_code) == 0 + ): # Allows to set custom label code in TOML + self.label_code = generate_label_code( + self.network_task, train_df, self.label + ) + + full_dataset = return_dataset( + self.caps_directory, + train_df, + self.preprocessing_dict, + multi_cohort=self.multi_cohort, + label=self.label, + label_code=self.label_code, + transforms_config=transfo_config, + ) + self.num_networks = full_dataset.elem_per_image + self.output_size = output_size( + self.network_task, full_dataset.size, full_dataset.df, self.label + ) + self.input_size = full_dataset.size + + if self.num_networks < 2 and self.multi_network: + raise ClinicaDLConfigurationError( + f"Invalid training configuration: cannot train a multi-network " + f"framework with only {self.num_networks} element " + f"per image." + ) + possible_selection_metrics_set = set(evaluation_metrics(self.network_task)) | { + "loss" + } + if not set(self.selection_metrics).issubset(possible_selection_metrics_set): + raise ClinicaDLConfigurationError( + f"Selection metrics {self.selection_metrics} " + f"must be a subset of metrics used for evaluation " + f"{possible_selection_metrics_set}." + ) + + @model_validator(mode="after") + def check_gpu(self) -> Self: + if self.gpu: + import torch + + if not torch.cuda.is_available(): + raise ClinicaDLArgumentError( + "No GPU is available. To run on CPU, please set gpu to false or add the --no-gpu flag if you use the commandline." + ) + elif self.amp: + raise ClinicaDLArgumentError( + "AMP is designed to work with modern GPUs. Please add the --gpu flag." + ) + return self + + @field_validator("track_exp", mode="before") + def check_track_exp(cls, v): + if v == "": + return None + + @field_validator("split", "diagnoses", "selection_metrics", mode="before") + def list_to_tuples(cls, v): + if isinstance(v, list): + return tuple(v) + return v # TODO : check that split exists (and check coherence with n_splits) + + def adapt_cross_val_with_maps_manager_info( + self, maps_manager + ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future + # TEMPORARY + if not self.split: + self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) + logger.debug(f"List of splits {self.split}") + + def create_groupe_df(self): + group_df = None + if self.data_tsv is not None and self.data_tsv.is_file(): + group_df = load_data_test( + self.data_tsv, + self.diagnoses, + multi_cohort=self.multi_cohort, + ) + return group_df + + def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): + return ( + self.label is not None + and self.label != "" + and self.label != _label + and _label_code == "default" + ) + + def check_label(self, _label: str): + if not self.label: + self.label = _label + + @field_validator("data_tsv", mode="before") + @classmethod + def check_data_tsv(cls, v) -> Path: + if v is not None: + if not isinstance(v, Path): + v = Path(v) + if not v.is_file(): + raise ClinicaDLTSVError( + "The participants_list you gave is not a file. Please give an existing file." 
+ ) + if v.stat().st_size == 0: + raise ClinicaDLTSVError( + "The participants_list you gave is empty. Please give a non-empty file." + ) + return v + + @computed_field + @property + def caps_dict(self) -> Dict[str, Path]: + from clinicadl.utils.iotools.clinica_utils import check_caps_folder + + if self.multi_cohort: + if self.caps_directory.suffix != ".tsv": + raise ClinicaDLArgumentError( + "If multi_cohort is True, the CAPS_DIRECTORY argument should be a path to a TSV file." + ) + else: + caps_df = pd.read_csv(self.caps_directory, sep="\t") + check_multi_cohort_tsv(caps_df, "CAPS") + caps_dict = dict() + for idx in range(len(caps_df)): + cohort = caps_df.loc[idx, "cohort"] + caps_path = Path(caps_df.at[idx, "path"]) + check_caps_folder(caps_path) + caps_dict[cohort] = caps_path + else: + check_caps_folder(self.caps_directory) + caps_dict = {"single": self.caps_directory} + + return caps_dict + + @model_validator(mode="after") + def check_preprocessing_dict(self) -> Self: + """ + Gets the preprocessing dictionary from a preprocessing json file. + + Returns + ------- + Dict[str, Any] + The preprocessing dictionary. + + Raises + ------ + ValueError + In case of multi-cohort dataset, if no preprocessing file is found in any CAPS. + """ + from clinicadl.caps_dataset.data import CapsDataset + + if self.preprocessing_dict is None: + if self.preprocessing_json is not None: + if not self.multi_cohort: + preprocessing_json = ( + self.caps_directory + / "tensor_extraction" + / self.preprocessing_json + ) + else: + caps_dict = self.caps_dict + json_found = False + for caps_name, caps_path in caps_dict.items(): + preprocessing_json = ( + caps_path / "tensor_extraction" / self.preprocessing_json + ) + if preprocessing_json.is_file(): + logger.info( + f"Preprocessing JSON {preprocessing_json} found in CAPS {caps_name}." + ) + json_found = True + if not json_found: + raise ValueError( + f"Preprocessing JSON {self.preprocessing_json} was not found for any CAPS " + f"in {caps_dict}." + ) + + self.preprocessing_dict = read_preprocessing(preprocessing_json) + + if ( + self.preprocessing_dict["mode"] == "roi" + and "roi_background_value" not in self.preprocessing_dict + ): + self.preprocessing_dict["roi_background_value"] = 0 + + return self + + @computed_field + @property + def mode(self) -> Mode: + return Mode(self.preprocessing_dict["mode"]) + + @field_validator("dropout") + def validator_dropout(cls, v): + assert ( + 0 <= v <= 1 + ), f"dropout must be between 0 and 1 but it has been set to {v}." + return v + + @computed_field + @property + def preprocessing_dict_target(self) -> Dict[str, Any]: + """ + Gets the preprocessing dictionary from a target preprocessing json file. + + Returns + ------- + Dict[str, Any] + The preprocessing dictionary. 
+ """ + if not self.ssda_network: + return {} + + preprocessing_json_target = ( + self.caps_target / "tensor_extraction" / self.preprocessing_json_target + ) + + return read_preprocessing(preprocessing_json_target) + + @field_validator("transfer_path", mode="before") + def validator_transfer_path(cls, v): + """Transforms a False to None.""" + if v is False: + return None + return v + + @field_validator("transfer_selection_metric") + def validator_transfer_selection_metric(cls, v): + return v # TODO : check if metric is in transfer MAPS + + @field_validator("data_augmentation", mode="before") + def validator_data_augmentation(cls, v): + """Transforms lists to tuples and False to empty tuple.""" + if isinstance(v, list): + return tuple(v) + if v is False: + return () + return v + + def get_transforms( + self, + ) -> Tuple[torch_transforms.Compose, torch_transforms.Compose]: + """ + Outputs the transformations that will be applied to the dataset + + Args: + normalize: if True will perform MinMaxNormalization. + data_augmentation: list of data augmentation performed on the training set. + + Returns: + transforms to apply in train and evaluation mode / transforms to apply in evaluation mode only. + """ + augmentation_dict = { + "Noise": transforms.RandomNoising(sigma=0.1), + "Erasing": torch_transforms.RandomErasing(), + "CropPad": transforms.RandomCropPad(10), + "Smoothing": transforms.RandomSmoothing(), + "Motion": transforms.RandomMotion((2, 4), (2, 4), 2), + "Ghosting": transforms.RandomGhosting((4, 10)), + "Spike": transforms.RandomSpike(1, (1, 3)), + "BiasField": transforms.RandomBiasField(0.5), + "RandomBlur": transforms.RandomBlur((0, 2)), + "RandomSwap": transforms.RandomSwap(15, 100), + "None": None, + } + + augmentation_list = [] + transformations_list = [] + + if self.data_augmentation: + augmentation_list.extend( + [ + augmentation_dict[augmentation] + for augmentation in self.data_augmentation + ] + ) + + transformations_list.append(transforms.NanRemoval()) + if self.normalize: + transformations_list.append(transforms.MinMaxNormalization()) + if self.size_reduction: + transformations_list.append( + transforms.SizeReduction(self.size_reduction_factor) + ) + + all_transformations = torch_transforms.Compose(transformations_list) + train_transformations = torch_transforms.Compose(augmentation_list) + + return train_transformations, all_transformations + + def check_output_saving_nifti(self, network_task: str) -> None: + # Check if task is reconstruction for "save_tensor" and "save_nifti" + if self.save_nifti and network_task != "reconstruction": + raise ClinicaDLArgumentError( + "Cannot save nifti if the network task is not reconstruction. Please remove --save_nifti option." 
+ ) diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index e86ac836d..d42047dcc 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List +from typing import List, Optional from clinicadl.utils.exceptions import ClinicaDLArgumentError @@ -66,3 +66,32 @@ def print_description_log( log_path = log_dir / "description.log" with log_path.open(mode="r") as f: content = f.read() + + +def init_split_manager( + validation, + parameters, + split_list=None, + ssda_bool: bool = False, + caps_target: Optional[Path] = None, + tsv_target_lab: Optional[Path] = None, +): + from clinicadl.validation import split_manager + + split_class = getattr(split_manager, validation) + args = list( + split_class.__init__.__code__.co_varnames[ + : split_class.__init__.__code__.co_argcount + ] + ) + args.remove("self") + args.remove("split_list") + kwargs = {"split_list": split_list} + for arg in args: + kwargs[arg] = parameters[arg] + + if ssda_bool: + kwargs["caps_directory"] = caps_target + kwargs["tsv_path"] = tsv_target_lab + + return split_class(**kwargs) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 3c279d155..b0a07bc8b 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -5,7 +5,7 @@ from datetime import datetime from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import pandas as pd import torch @@ -34,6 +34,8 @@ ) from clinicadl.trainer.tasks_utils import create_training_config from clinicadl.validator.validator import Validator +from clinicadl.splitter.split_utils import init_split_manager +from clinicadl.transforms.config import TransformsConfig if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback @@ -62,6 +64,7 @@ def __init__( config : TrainConfig """ self.config = config + self.maps_manager = self._init_maps_manager(config) self.validator = Validator() self._check_args() @@ -70,11 +73,13 @@ def _init_maps_manager(self, config) -> MapsManager: # temporary: to match CLI data. TODO : change CLI data parameters, maps_path = create_parameters_dict(config) + if maps_path.is_dir(): return MapsManager( maps_path, verbose=None ) # TODO : precise which parameters in config are useful else: + # parameters["maps_path"] = maps_path return MapsManager( maps_path, parameters, verbose=None ) # TODO : precise which parameters in config are useful @@ -161,6 +166,11 @@ def resume(self, splits: List[int]) -> None: ) ) # TODO : check these two lines. Why do we need a split_manager? + # split_manager = init_split_manager( + # validation=self.maps_manager.validation, + # parameters=self.config.model_dump(), + # split_list=splits, + # ) split_manager = self.maps_manager._init_split_manager(split_list=splits) split_iterator = split_manager.split_iterator() ### @@ -209,9 +219,42 @@ def train( MAPSError If splits specified in input already exist and overwrite is False. 
""" - existing_splits = [] + self.check_split_list(split_list=split_list, overwrite=overwrite) + + if self.config.ssda.ssda_network: + self._train_ssda(split_list, resume=False) + + else: + split_manager = self.maps_manager._init_split_manager(split_list) + # split_manager = init_split_manager( + # self.maps_manager.validation, self.config.model_dump(), split_list + # ) + for split in split_manager.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) + + split_df_dict = split_manager[split] + + if self.config.model.multi_network: + resume, first_network = self.init_first_network(False, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=False) + + def check_split_list(self, split_list, overwrite): + existing_splits = [] split_manager = self.maps_manager._init_split_manager(split_list) + # split_manager = init_split_manager( + # self.maps_manager.validation, self.config.model_dump(), split_list + # ) for split in split_manager.split_iterator(): split_path = ( self.maps_manager.maps_path / f"{self.maps_manager.split_name}-{split}" @@ -230,13 +273,6 @@ def train( f"or use overwrite to erase previously trained splits." ) - if self.config.model.multi_network: - self._train_multi(split_list, resume=False) - elif self.config.ssda.ssda_network: - self._train_ssda(split_list, resume=False) - else: - self._train_single(split_list, resume=False) - def _resume( self, split_list: Optional[List[int]] = None, @@ -256,8 +292,10 @@ def _resume( If splits specified in input do not exist. """ missing_splits = [] + # split_manager = init_split_manager( + # self.maps_manager.validation, self.config.model_dump(), split_list + # ) split_manager = self.maps_manager._init_split_manager(split_list) - for split in split_manager.split_iterator(): if not ( self.maps_manager.maps_path @@ -272,127 +310,116 @@ def _resume( f"Please try train command on these splits and resume only others." ) - if self.config.model.multi_network: - self._train_multi(split_list, resume=True) - elif self.config.ssda.ssda_network: + if self.config.ssda.ssda_network: self._train_ssda(split_list, resume=True) else: - self._train_single(split_list, resume=True) - - def _train_single( - self, - split_list: Optional[List[int]] = None, - resume: bool = False, - ) -> None: - """ - Trains a single CNN for all inputs. - - Parameters - ---------- - split_list : Optional[List[int]] (optional, default=None) - List of splits on which the training task is performed. - If None, performs training on all splits of the cross-validation. - resume : bool (optional, default=False) - If True, the job is resumed from checkpoint. 
- """ - # train_transforms, all_transforms = self.config.transforms.get_transforms() + for split in split_manager.split_iterator(): + logger.info(f"Training split {split}") + seed_everything( + self.config.reproducibility.seed, + self.config.reproducibility.deterministic, + self.config.reproducibility.compensation, + ) - split_manager = self.maps_manager._init_split_manager(split_list) - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) + split_df_dict = split_manager[split] + if self.config.model.multi_network: + resume, first_network = self.init_first_network(True, split) + for network in range(first_network, self.maps_manager.num_networks): + self._train_single( + split, split_df_dict, network=network, resume=resume + ) + else: + self._train_single(split, split_df_dict, resume=True) - split_df_dict = split_manager[split] + def init_first_network(self, resume, split): + first_network = 0 + if resume: + training_logs = [ + int(network_folder.split("-")[1]) + for network_folder in list( + ( + self.maps_manager.maps_path + / f"{self.maps_manager.split_name}-{split}" + / "training_logs" + ).iterdir() + ) + ] + first_network = max(training_logs) + if not (self.maps_manager.maps_path / "tmp").is_dir(): + first_network += 1 + resume = False + return resume, first_network - logger.debug("Loading training data...") - data_train = return_dataset( - self.config.data.caps_directory, - split_df_dict["train"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - logger.debug("Loading validation data...") - data_valid = return_dataset( - self.config.data.caps_directory, - split_df_dict["validation"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - ) - train_sampler = generate_sampler( - self.maps_manager.network_task, - data_train, - self.config.dataloader.sampler, - dp_degree=cluster.world_size, - rank=cluster.rank, - ) - logger.debug( - f"Getting train and validation loader with batch size {self.config.dataloader.batch_size}" - ) - train_loader = DataLoader( - data_train, - batch_size=self.config.dataloader.batch_size, - sampler=train_sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - ) - logger.debug(f"Train loader size is {len(train_loader)}") - valid_sampler = DistributedSampler( - data_valid, - num_replicas=cluster.world_size, - rank=cluster.rank, - shuffle=False, - ) - valid_loader = DataLoader( - data_valid, - batch_size=self.config.dataloader.batch_size, - shuffle=False, - num_workers=self.config.dataloader.n_proc, - sampler=valid_sampler, + def get_dataloader( + self, + input_dir: Path, + data_df: pd.DataFrame, + preprocessing_dict: Dict[str, Any], + transforms_config: TransformsConfig, + label: Optional[str] = None, + label_code: Optional[Dict[str, int]] = None, + cnn_index: Optional[int] = None, + label_presence: bool = True, + multi_cohort: bool = False, + network_task: Union[str, Task] = "classification", + sampler_option: str = "random", + n_bins: int = 5, + dp_degree: Optional[int] = None, + rank: Optional[int] = None, + batch_size: Optional[int] = 
None, + n_proc: Optional[int] = None, + worker_init_fn: Optional[function] = None, + shuffle: Optional[bool] = None, + num_replicas: Optional[int] = None, + homemade_sampler: bool = False, + ): + dataset = return_dataset( + input_dir=input_dir, + data_df=data_df, + preprocessing_dict=preprocessing_dict, + transforms_config=transforms_config, + multi_cohort=multi_cohort, + label=label, + label_code=label_code, + cnn_index=cnn_index, + ) + if homemade_sampler: + sampler = generate_sampler( + network_task=network_task, + dataset=dataset, + sampler_option=sampler_option, + dp_degree=dp_degree, + rank=rank, ) - logger.debug(f"Validation loader size is {len(valid_loader)}") - from clinicadl.callbacks.callbacks import CodeCarbonTracker - - self._train( - train_loader, - valid_loader, - split, - resume=resume, - callbacks=[CodeCarbonTracker], + else: + sampler = DistributedSampler( + dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, ) - if cluster.master: - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) + train_loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=n_proc, + worker_init_fn=worker_init_fn, + shuffle=shuffle, + ) + logger.debug(f"Train loader size is {len(train_loader)}") - self._erase_tmp(split) + return train_loader - def _train_multi( + def _train_single( self, - split_list: Optional[List[int]] = None, + split, + split_df_dict: Dict, + network: Optional[int] = None, resume: bool = False, ) -> None: """ - Trains a CNN per element in the image (e.g. per slice). + Trains a single CNN for all inputs. Parameters ---------- @@ -402,115 +429,79 @@ def _train_multi( resume : bool (optional, default=False) If True, the job is resumed from checkpoint. 
""" - # train_transforms, all_transforms = self.config.transforms.get_transforms() - - split_manager = self.maps_manager._init_split_manager(split_list) - for split in split_manager.split_iterator(): - logger.info(f"Training split {split}") - seed_everything( - self.config.reproducibility.seed, - self.config.reproducibility.deterministic, - self.config.reproducibility.compensation, - ) - split_df_dict = split_manager[split] + logger.debug("Loading training data...") + + train_loader = self.get_dataloader( + input_dir=self.config.data.caps_directory, + data_df=split_df_dict["train"], + preprocessing_dict=self.config.data.preprocessing_dict, + transforms_config=self.config.transforms, + multi_cohort=self.config.data.multi_cohort, + label=self.config.data.label, + label_code=self.maps_manager.label_code, + cnn_index=network, + network_task=self.maps_manager.network_task, + sampler_option=self.config.dataloader.sampler, + dp_degree=cluster.world_size, + rank=cluster.rank, + batch_size=self.config.dataloader.batch_size, + n_proc=self.config.dataloader.n_proc, + worker_init_fn=pl_worker_init_function, + homemade_sampler=True, + ) - first_network = 0 - if resume: - training_logs = [ - int(network_folder.split("-")[1]) - for network_folder in list( - ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "training_logs" - ).iterdir() - ) - ] - first_network = max(training_logs) - if not (self.maps_manager.maps_path / "tmp").is_dir(): - first_network += 1 - resume = False - - for network in range(first_network, self.maps_manager.num_networks): - logger.info(f"Train network {network}") - - data_train = return_dataset( - self.config.data.caps_directory, - split_df_dict["train"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - cnn_index=network, - ) - data_valid = return_dataset( - self.config.data.caps_directory, - split_df_dict["validation"], - self.config.data.preprocessing_dict, - transforms_config=self.config.transforms, - multi_cohort=self.config.data.multi_cohort, - label=self.config.data.label, - label_code=self.maps_manager.label_code, - cnn_index=network, - ) + logger.debug(f"Train loader size is {len(train_loader)}") + logger.debug("Loading validation data...") + + valid_loader = self.get_dataloader( + input_dir=self.config.data.caps_directory, + data_df=split_df_dict["validation"], + preprocessing_dict=self.config.data.preprocessing_dict, + transforms_config=self.config.transforms, + multi_cohort=self.config.data.multi_cohort, + label=self.config.data.label, + label_code=self.maps_manager.label_code, + cnn_index=network, + network_task=self.maps_manager.network_task, + num_replicas=cluster.world_size, + rank=cluster.rank, + batch_size=self.config.dataloader.batch_size, + n_proc=self.config.dataloader.n_proc, + shuffle=False, + homemade_sampler=False, + ) - train_sampler = generate_sampler( - self.maps_manager.network_task, - data_train, - self.config.dataloader.sampler, - dp_degree=cluster.world_size, - rank=cluster.rank, - ) - train_loader = DataLoader( - data_train, - batch_size=self.config.dataloader.batch_size, - sampler=train_sampler, - num_workers=self.config.dataloader.n_proc, - worker_init_fn=pl_worker_init_function, - ) + logger.debug(f"Validation loader size is {len(valid_loader)}") + from clinicadl.callbacks.callbacks import CodeCarbonTracker - valid_sampler = DistributedSampler( - data_valid, - 
num_replicas=cluster.world_size, - rank=cluster.rank, - shuffle=False, - ) - valid_loader = DataLoader( - data_valid, - batch_size=self.config.dataloader.batch_size, - shuffle=False, - num_workers=self.config.dataloader.n_proc, - sampler=valid_sampler, - ) - from clinicadl.callbacks.callbacks import CodeCarbonTracker + self._train( + train_loader, + valid_loader, + split, + resume=resume, + callbacks=[CodeCarbonTracker], + network=network, + ) - self._train( - train_loader, - valid_loader, - split, - network, - resume=resume, - callbacks=[CodeCarbonTracker], - ) - resume = False + if network is not None: + resume = False - if cluster.master: - self.validator._ensemble_prediction( - self.maps_manager, - "train", - split, - self.config.validation.selection_metrics, - ) - self.validator._ensemble_prediction( - self.maps_manager, - "validation", - split, - self.config.validation.selection_metrics, - ) + if cluster.master: + self.validator._ensemble_prediction( + self.maps_manager, + "train", + split, + self.config.validation.selection_metrics, + ) + self.validator._ensemble_prediction( + self.maps_manager, + "validation", + split, + self.config.validation.selection_metrics, + ) - self._erase_tmp(split) + self._erase_tmp(split) def _train_ssda( self, @@ -528,7 +519,6 @@ def _train_ssda( resume : bool (optional, default=False) If True, the job is resumed from checkpoint. """ - # train_transforms, all_transforms = self.config.transforms.get_transforms() split_manager = self.maps_manager._init_split_manager(split_list) split_manager_target_lab = self.maps_manager._init_split_manager(
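
A minimal usage sketch of the new Trainer.get_dataloader helper introduced in trainer.py above, assuming a configured `trainer` and a `split_df_dict` taken from `split_manager[split]`, as in `_train_single`; `cluster` and `pl_worker_init_function` are the same helpers already used in trainer.py.

# Sketch only: build the training loader for one split with the new helper.
# For a single-network model, `cnn_index` stays None.
train_loader = trainer.get_dataloader(
    input_dir=trainer.config.data.caps_directory,
    data_df=split_df_dict["train"],
    preprocessing_dict=trainer.config.data.preprocessing_dict,
    transforms_config=trainer.config.transforms,
    multi_cohort=trainer.config.data.multi_cohort,
    label=trainer.config.data.label,
    label_code=trainer.maps_manager.label_code,
    cnn_index=None,
    network_task=trainer.maps_manager.network_task,
    sampler_option=trainer.config.dataloader.sampler,
    dp_degree=cluster.world_size,
    rank=cluster.rank,
    batch_size=trainer.config.dataloader.batch_size,
    n_proc=trainer.config.dataloader.n_proc,
    worker_init_fn=pl_worker_init_function,
    homemade_sampler=True,  # generate_sampler branch, as for the train loader in _train_single
)

The validation loader is built with the same helper by passing homemade_sampler=False, shuffle=False and num_replicas=cluster.world_size, which selects the DistributedSampler branch instead.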