From 765f29a69b44aef39dccd25f42f02b00a81256c3 Mon Sep 17 00:00:00 2001
From: camillebrianceau
Date: Thu, 3 Oct 2024 16:32:21 +0200
Subject: [PATCH] remove split_name parameters

---
 .../{cross_validation.py => split.py}         |  0
 .../commandline/pipelines/predict/cli.py      |  4 +-
 .../pipelines/train/classification/cli.py     |  6 +--
 .../pipelines/train/from_json/cli.py          |  4 +-
 .../pipelines/train/reconstruction/cli.py     |  6 +--
 .../pipelines/train/regression/cli.py         |  6 +--
 .../commandline/pipelines/train/resume/cli.py |  4 +-
 clinicadl/maps_manager/maps_manager.py        | 47 ++++---------
 clinicadl/metrics/utils.py                    | 21 +++------
 clinicadl/predict/predict_manager.py          | 20 ++++----
 clinicadl/predict/utils.py                    | 11 ++---
 clinicadl/splitter/config.py                  |  2 +-
 clinicadl/splitter/split_utils.py             | 27 +++++------
 clinicadl/trainer/trainer.py                  | 46 ++++--------
 clinicadl/utils/meta_maps/getter.py           |  4 +-
 clinicadl/validator/config.py                 |  1 -
 clinicadl/validator/validator.py              | 10 ++--
 tests/test_predict.py                         |  2 -
 18 files changed, 69 insertions(+), 152 deletions(-)
 rename clinicadl/commandline/modules_options/{cross_validation.py => split.py} (100%)

diff --git a/clinicadl/commandline/modules_options/cross_validation.py b/clinicadl/commandline/modules_options/split.py
similarity index 100%
rename from clinicadl/commandline/modules_options/cross_validation.py
rename to clinicadl/commandline/modules_options/split.py
diff --git a/clinicadl/commandline/pipelines/predict/cli.py b/clinicadl/commandline/pipelines/predict/cli.py
index c4cdaf1a1..fa7303008 100644
--- a/clinicadl/commandline/pipelines/predict/cli.py
+++ b/clinicadl/commandline/pipelines/predict/cli.py
@@ -3,10 +3,10 @@
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
     computational,
-    cross_validation,
     data,
     dataloader,
     maps_manager,
+    split,
     validation,
 )
 from clinicadl.commandline.pipelines.predict import options
@@ -29,7 +29,7 @@
 @data.diagnoses
 @validation.skip_leak_check
 @validation.selection_metrics
-@cross_validation.split
+@split.split
 @computational.gpu
 @computational.amp
 @dataloader.n_proc
diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py
index 6b3c7875d..6a7814255 100644
--- a/clinicadl/commandline/pipelines/train/classification/cli.py
+++ b/clinicadl/commandline/pipelines/train/classification/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -70,8 +70,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
diff --git a/clinicadl/commandline/pipelines/train/from_json/cli.py b/clinicadl/commandline/pipelines/train/from_json/cli.py
index ab613d330..c0130a9b9 100644
--- a/clinicadl/commandline/pipelines/train/from_json/cli.py
+++ b/clinicadl/commandline/pipelines/train/from_json/cli.py
@@ -4,7 +4,7 @@
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
-    cross_validation,
+    split,
 )
 from clinicadl.trainer.trainer import Trainer
 
 
@@ -12,7 +12,7 @@
 @click.command(name="from_json", no_args_is_help=True)
 @arguments.config_file
 @arguments.output_maps
-@cross_validation.split
+@split.split
 def cli(**kwargs):
     """
     Replicate a deep learning training based on a previously created JSON file.
diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
index db79eb591..37bd50b41 100644
--- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py
+++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -70,8 +70,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py
index c00e700ec..fbc48e5b9 100644
--- a/clinicadl/commandline/pipelines/train/regression/cli.py
+++ b/clinicadl/commandline/pipelines/train/regression/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -68,8 +68,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
diff --git a/clinicadl/commandline/pipelines/train/resume/cli.py b/clinicadl/commandline/pipelines/train/resume/cli.py
index 8734bf95d..1fc34a0f4 100644
--- a/clinicadl/commandline/pipelines/train/resume/cli.py
+++ b/clinicadl/commandline/pipelines/train/resume/cli.py
@@ -2,14 +2,14 @@
 
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
-    cross_validation,
+    split,
 )
 from clinicadl.trainer.trainer import Trainer
 
 
 @click.command(name="resume", no_args_is_help=True)
 @arguments.input_maps
-@cross_validation.split
+@split.split
 def cli(input_maps_directory, split):
     """Resume training job in specified maps.
 
diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index b1c4ebeb6..839a0044c 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -105,17 +105,11 @@ def __init__(
                 f"Please choose between classification, regression and reconstruction."
             )
 
-            self.split_name = (
-                self._check_split_wording()
-            )  # Used only for retro-compatibility
-
 
         # Initiate MAPS
         else:
-            print(parameters)
             self._check_args(parameters)
             parameters["tsv_path"] = Path(parameters["tsv_path"])
-            self.split_name = "split"  # Used only for retro-compatibility
             if cluster.master:
                 if (maps_path.is_dir() and maps_path.is_file()) or (  # Non-folder file
                     maps_path.is_dir() and list(maps_path.iterdir())  # Non empty folder
@@ -326,12 +320,7 @@ def _write_train_val_groups(self):
         for split in split_manager.split_iterator():
             for data_group in ["train", "validation"]:
                 df = split_manager[split][data_group]
-                group_path = (
-                    self.maps_path
-                    / "groups"
-                    / data_group
-                    / f"{self.split_name}-{split}"
-                )
+                group_path = self.maps_path / "groups" / data_group / f"split-{split}"
                 group_path.mkdir(parents=True, exist_ok=True)
 
                 columns = ["participant_id", "session_id", "cohort"]
@@ -422,10 +411,7 @@ def _mode_level_to_tsv(
             data_group: the name referring to the data group on which evaluation is performed.
         """
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
         performance_dir.mkdir(parents=True, exist_ok=True)
         performance_path = (
@@ -482,7 +468,6 @@ def _ensemble_to_tsv(
             validation_dataset = "validation"
         test_df = get_prediction(
             self.maps_path,
-            self.split_name,
             data_group,
             split,
             selection,
@@ -491,7 +476,6 @@ def _ensemble_to_tsv(
         )
         validation_df = get_prediction(
             self.maps_path,
-            self.split_name,
             validation_dataset,
             split,
             selection,
@@ -500,10 +484,7 @@ def _ensemble_to_tsv(
         )
 
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
         performance_dir.mkdir(parents=True, exist_ok=True)
 
@@ -551,7 +532,6 @@ def _mode_to_image_tsv(
         """
         sub_df = get_prediction(
             self.maps_path,
-            self.split_name,
             data_group,
             split,
             selection,
@@ -561,10 +541,7 @@ def _mode_to_image_tsv(
         sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True)
 
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
         sub_df.to_csv(
             performance_dir / f"{data_group}_image_level_prediction.tsv",
@@ -638,10 +615,7 @@ def _init_model(
 
         if resume:
             checkpoint_path = (
-                self.maps_path
-                / f"{self.split_name}-{split}"
-                / "tmp"
-                / "checkpoint.pth.tar"
+                self.maps_path / f"split-{split}" / "tmp" / "checkpoint.pth.tar"
             )
             checkpoint_state = torch.load(
                 checkpoint_path, map_location=device, weights_only=True
@@ -694,10 +668,7 @@ def _print_description_log(
             selection_metric (str): Metric used for best weights selection.
""" log_dir = ( - self.maps_path - / f"{self.split_name}-{split}" - / f"best-{selection_metric}" - / data_group + self.maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) log_path = log_dir / "description.log" with log_path.open(mode="r") as f: @@ -767,7 +738,7 @@ def get_state_dict( (Dict): dictionary of results (weights, epoch number, metrics values) """ selection_metric = check_selection_metric( - self.maps_path, self.split_name, split, selection_metric + self.maps_path, split, selection_metric ) if self.multi_network: if network is None: @@ -777,14 +748,14 @@ def get_state_dict( else: model_path = ( self.maps_path - / f"{self.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / f"network-{network}_model.pth.tar" ) else: model_path = ( self.maps_path - / f"{self.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / "model.pth.tar" ) diff --git a/clinicadl/metrics/utils.py b/clinicadl/metrics/utils.py index c39cf80f8..dec32c524 100644 --- a/clinicadl/metrics/utils.py +++ b/clinicadl/metrics/utils.py @@ -7,10 +7,10 @@ from clinicadl.utils.exceptions import ClinicaDLArgumentError, MAPSError -def find_selection_metrics(maps_path: Path, split_name: str, split): +def find_selection_metrics(maps_path: Path, split): """Find which selection metrics are available in MAPS for a given split.""" - split_path = maps_path / f"{split_name}-{split}" + split_path = maps_path / f"split-{split}" if not split_path.is_dir(): raise KeyError( f"Training of split {split} was not performed." @@ -24,11 +24,9 @@ def find_selection_metrics(maps_path: Path, split_name: str, split): ] -def check_selection_metric( - maps_path: Path, split_name: str, split, selection_metric=None -): +def check_selection_metric(maps_path: Path, split, selection_metric=None): """Check that a given selection metric is available for a given split.""" - available_metrics = find_selection_metrics(maps_path, split_name, split) + available_metrics = find_selection_metrics(maps_path, split) if not selection_metric: if len(available_metrics) > 1: @@ -49,7 +47,6 @@ def check_selection_metric( def get_metrics( maps_path: Path, - split_name: str, data_group: str, split: int = 0, selection_metric: Optional[str] = None, @@ -68,15 +65,11 @@ def get_metrics( Returns: (dict[str:float]): Values of the metrics """ - selection_metric = check_selection_metric( - maps_path, split_name, split, selection_metric - ) + selection_metric = check_selection_metric(maps_path, split, selection_metric) if verbose: - print_description_log( - maps_path, split_name, data_group, split, selection_metric - ) + print_description_log(maps_path, data_group, split, selection_metric) prediction_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group + maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) if not prediction_dir.is_dir(): raise MAPSError( diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index c197a96de..55515dc8e 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -148,7 +148,6 @@ def predict( if not self._config.selection_metrics: split_selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) else: @@ -156,7 +155,7 @@ def predict( for selection in split_selection_metrics: tsv_dir = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection}" / 
self._config.data_group ) @@ -497,7 +496,7 @@ def _compute_latent_tensors( model.eval() tensor_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / "latent_tensors" @@ -574,7 +573,7 @@ def _compute_output_nifti( model.eval() nifti_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / "nifti_images" @@ -719,14 +718,13 @@ def interpret(self): if not self._config.selection_metrics: self._config.selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) for selection_metric in self._config.selection_metrics: logger.info(f"Interpretation of metric {selection_metric}") results_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / self._config.data_group / f"interpret-{self._config.name}" @@ -843,13 +841,12 @@ def _check_data_group( for split in self._config.split: selection_metrics = find_selection_metrics( self.maps_manager.maps_path, - self.maps_manager.split_name, split, ) for selection in selection_metrics: results_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection}" / self._config.data_group ) @@ -927,12 +924,12 @@ def get_group_info( "Information on train or validation data can only be " "loaded if a split number is given" ) - elif not (group_path / f"{self.maps_manager.split_name}-{split}").is_dir(): + elif not (group_path / f"split-{split}").is_dir(): raise MAPSError( f"Split {split} is not available for data group {data_group}." ) else: - group_path = group_path / f"{self.maps_manager.split_name}-{split}" + group_path = group_path / f"split-{split}" df = pd.read_csv(group_path / "data.tsv", sep="\t") json_path = group_path / "maps.json" @@ -1054,7 +1051,6 @@ def get_interpretation( selection_metric = check_selection_metric( self.maps_manager.maps_path, - self.maps_manager.split_name, split, selection_metric, ) @@ -1064,7 +1060,7 @@ def get_interpretation( ) map_dir = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group / f"interpret-{name}" diff --git a/clinicadl/predict/utils.py b/clinicadl/predict/utils.py index e547fb5b1..c66372764 100644 --- a/clinicadl/predict/utils.py +++ b/clinicadl/predict/utils.py @@ -10,7 +10,6 @@ def get_prediction( maps_path: Path, - split_name: str, data_group: str, split: int = 0, selection_metric: Optional[str] = None, @@ -31,15 +30,11 @@ def get_prediction( (DataFrame): Results indexed by columns 'participant_id' and 'session_id' which identifies the image in the BIDS / CAPS. 
""" - selection_metric = check_selection_metric( - maps_path, split_name, split, selection_metric - ) + selection_metric = check_selection_metric(maps_path, split, selection_metric) if verbose: - print_description_log( - maps_path, split_name, data_group, split, selection_metric - ) + print_description_log(maps_path, data_group, split, selection_metric) prediction_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group + maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group ) if not prediction_dir.is_dir(): raise MAPSError( diff --git a/clinicadl/splitter/config.py b/clinicadl/splitter/config.py index 62e4fd931..8bf639f09 100644 --- a/clinicadl/splitter/config.py +++ b/clinicadl/splitter/config.py @@ -39,7 +39,7 @@ def adapt_cross_val_with_maps_manager_info( ): # maps_manager is of type MapsManager but need to be in a MapsConfig type in the future # TEMPORARY if not self.split: - self.split = find_splits(maps_manager.maps_path, maps_manager.split_name) + self.split = find_splits(maps_manager.maps_path) logger.debug(f"List of splits {self.split}") diff --git a/clinicadl/splitter/split_utils.py b/clinicadl/splitter/split_utils.py index 6465029a2..1bf5ca457 100644 --- a/clinicadl/splitter/split_utils.py +++ b/clinicadl/splitter/split_utils.py @@ -2,40 +2,38 @@ from typing import List, Optional -def find_splits(maps_path: Path, split_name: str) -> List[int]: +def find_splits(maps_path: Path) -> List[int]: """Find which splits that were trained in the MAPS.""" splits = [ int(split.name.split("-")[1]) for split in list(maps_path.iterdir()) - if split.name.startswith(f"{split_name}-") + if split.name.startswith(f"split-") ] return splits -def find_stopped_splits(maps_path: Path, split_name: str) -> List[int]: +def find_stopped_splits(maps_path: Path) -> List[int]: """Find which splits for which training was not completed.""" - existing_split_list = find_splits(maps_path, split_name) + existing_split_list = find_splits(maps_path) stopped_splits = [ split for split in existing_split_list - if (maps_path / f"{split_name}-{split}" / "tmp") - in list((maps_path / f"{split_name}-{split}").iterdir()) + if (maps_path / f"split-{split}" / "tmp") + in list((maps_path / f"split-{split}").iterdir()) ] return stopped_splits -def find_finished_splits(maps_path: Path, split_name: str) -> List[int]: +def find_finished_splits(maps_path: Path) -> List[int]: """Find which splits for which training was completed.""" finished_splits = list() - existing_split_list = find_splits(maps_path, split_name) - stopped_splits = find_stopped_splits(maps_path, split_name) + existing_split_list = find_splits(maps_path) + stopped_splits = find_stopped_splits(maps_path) for split in existing_split_list: if split not in stopped_splits: performance_dir_list = [ performance_dir - for performance_dir in list( - (maps_path / f"{split_name}-{split}").iterdir() - ) + for performance_dir in list((maps_path / f"split-{split}").iterdir()) if "best-" in performance_dir.name ] if len(performance_dir_list) > 0: @@ -45,7 +43,6 @@ def find_finished_splits(maps_path: Path, split_name: str) -> List[int]: def print_description_log( maps_path: Path, - split_name: str, data_group: str, split: int, selection_metric: str, @@ -58,9 +55,7 @@ def print_description_log( split (int): Index of the split used for training. selection_metric (str): Metric used for best weights selection. 
""" - log_dir = ( - maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group - ) + log_dir = maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group log_path = log_dir / "description.log" with log_path.open(mode="r") as f: content = f.read() diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index c7b071514..ec3e99eab 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -155,16 +155,8 @@ def resume(self, splits: List[int]) -> None: splits : List[int] The splits that must be resumed. """ - stopped_splits = set( - find_stopped_splits( - self.config.maps_manager.maps_dir, self.maps_manager.split_name - ) - ) - finished_splits = set( - find_finished_splits( - self.maps_manager.maps_path, self.maps_manager.split_name - ) - ) + stopped_splits = set(find_stopped_splits(self.config.maps_manager.maps_dir)) + finished_splits = set(find_finished_splits(self.maps_manager.maps_path)) # TODO : check these two lines. Why do we need a split_manager? split_manager = init_splitter( parameters=self.config.get_dict(), @@ -254,9 +246,7 @@ def check_split_list(self, split_list, overwrite): split_list=split_list, ) for split in split_manager.split_iterator(): - split_path = ( - self.maps_manager.maps_path / f"{self.maps_manager.split_name}-{split}" - ) + split_path = self.maps_manager.maps_path / f"split-{split}" if split_path.is_dir(): if overwrite: if cluster.master: @@ -295,11 +285,7 @@ def _resume( split_list=split_list, ) for split in split_manager.split_iterator(): - if not ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ).is_dir(): + if not (self.maps_manager.maps_path / f"split-{split}" / "tmp").is_dir(): missing_splits.append(split) if len(missing_splits) > 0: @@ -336,9 +322,7 @@ def init_first_network(self, resume, split): int(network_folder.split("-")[1]) for network_folder in list( ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "training_logs" + self.maps_manager.maps_path / f"split-{split}" / "training_logs" ).iterdir() ) ] @@ -1506,7 +1490,7 @@ def _init_optimizer( if resume: checkpoint_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / "tmp" / "optimizer.pth.tar" ) @@ -1562,11 +1546,7 @@ def _erase_tmp(self, split: int): split : int The split on which the model has been trained. """ - tmp_path = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ) + tmp_path = self.maps_manager.maps_path / f"split-{split}" / "tmp" shutil.rmtree(tmp_path) def _write_weights( @@ -1599,20 +1579,14 @@ def _write_weights( Whether to save model weights at every epoch. If False, only the best model will be saved. 
""" - checkpoint_dir = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "tmp" - ) + checkpoint_dir = self.maps_manager.maps_path / f"split-{split}" / "tmp" checkpoint_dir.mkdir(parents=True, exist_ok=True) checkpoint_path = checkpoint_dir / filename torch.save(state, checkpoint_path) if save_all_models: all_models_dir = ( - self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" - / "all_models" + self.maps_manager.maps_path / f"split-{split}" / "all_models" ) all_models_dir.mkdir(parents=True, exist_ok=True) torch.save(state, all_models_dir / f"model_epoch_{state['epoch']}.pth.tar") @@ -1626,7 +1600,7 @@ def _write_weights( for metric_name, metric_bool in metrics_dict.items(): metric_path = ( self.maps_manager.maps_path - / f"{self.maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{metric_name}" ) if metric_bool: diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 6ea53eac9..1fa524950 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -36,20 +36,18 @@ def meta_maps_analysis(launch_dir: Path, evaluation_metric="loss"): for job in jobs_list: performances_dict[job] = dict() maps_manager = MapsManager(launch_dir / job) - split_list = find_splits(maps_manager.maps_path, maps_manager.split_name) + split_list = find_splits(maps_manager.maps_path) split_set = split_set | set(split_list) for split in split_set: performances_dict[job][split] = dict() selection_metrics = find_selection_metrics( maps_manager.maps_path, - maps_manager.split_name, split, ) selection_set = selection_set | set(selection_metrics) for metric in selection_metrics: validation_metrics = get_metrics( maps_manager.maps_path, - maps_manager.split_name, "validation", split, metric, diff --git a/clinicadl/validator/config.py b/clinicadl/validator/config.py index 165b36dd0..2f8c8a30a 100644 --- a/clinicadl/validator/config.py +++ b/clinicadl/validator/config.py @@ -18,7 +18,6 @@ class ValidatorConfig(BaseModel): maps_path: Path mode: str network_task: str - split_name: Optional[str] = None num_networks: Optional[int] = None fsdp: Optional[bool] = None amp: Optional[bool] = None diff --git a/clinicadl/validator/validator.py b/clinicadl/validator/validator.py index d55810299..c8f5e9451 100644 --- a/clinicadl/validator/validator.py +++ b/clinicadl/validator/validator.py @@ -249,7 +249,7 @@ def _test_loader( if cluster.master: log_dir = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group ) @@ -339,7 +339,7 @@ def _test_loader_ssda( for selection_metric in selection_metrics: log_dir = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group ) @@ -430,7 +430,7 @@ def _compute_output_tensors( tensor_path = ( maps_manager.maps_path - / f"{maps_manager.split_name}-{split}" + / f"split-{split}" / f"best-{selection_metric}" / data_group / "tensors" @@ -475,9 +475,7 @@ def _ensemble_prediction( """Computes the results on the image-level.""" if not selection_metrics: - selection_metrics = find_selection_metrics( - maps_manager.maps_path, maps_manager.split_name, split - ) + selection_metrics = find_selection_metrics(maps_manager.maps_path, split) for selection_metric in selection_metrics: ##################### diff --git a/tests/test_predict.py b/tests/test_predict.py index 2c26a4a3e..849f0e20d 100644 --- a/tests/test_predict.py +++ 
b/tests/test_predict.py @@ -119,14 +119,12 @@ def test_predict(cmdopt, tmp_path, test_name): for mode in modes: get_prediction( predict_manager.maps_manager.maps_path, - predict_manager.maps_manager.split_name, data_group="test-RANDOM", mode=mode, ) if use_labels: get_metrics( predict_manager.maps_manager.maps_path, - predict_manager.maps_manager.split_name, data_group="test-RANDOM", mode=mode, )
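Usage sketch (not part of the patch): with split_name removed, split directories are always named
"split-<i>" and the helpers take only the MAPS path. A minimal example of the updated call
signatures, assuming a hypothetical MAPS directory at results/maps:

    from pathlib import Path

    from clinicadl.metrics.utils import find_selection_metrics, get_metrics
    from clinicadl.splitter.split_utils import find_splits

    maps_path = Path("results/maps")  # hypothetical MAPS directory
    for split in find_splits(maps_path):  # scans for "split-<i>" sub-directories
        for metric in find_selection_metrics(maps_path, split):
            # get_metrics(maps_path, data_group, split, selection_metric)
            metrics = get_metrics(maps_path, "validation", split, metric)
            print(split, metric, metrics)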