diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index 4d6adf2fc..d6fa8d300 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -39,7 +39,7 @@ class PredictManager: def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None: self.maps_manager = MapsManager(_config.maps_dir) self._config = _config - self.validator = Validator() + self.validator = Validator(**_config.model_dump()) def predict( self, @@ -48,49 +48,8 @@ def predict( """Performs the prediction task on a subset of caps_directory defined in a TSV file. Parameters ---------- - data_group : str - name of the data group tested. - caps_directory : Path (optional, default=None) - path to the CAPS folder. For more information please refer to - [clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/). - Default will load the value of an existing data group - tsv_path : Path (optional, default=None) - path to a TSV file containing the list of participants and sessions to test. - Default will load the DataFrame of an existing data group - split_list : List[int] (optional, default=None) - list of splits to test. Default perform prediction on all splits available. - selection_metrics : List[str] (optional, default=None) - list of selection metrics to test. - Default performs the prediction on all selection metrics available. - multi_cohort : bool (optional, default=False) - If True considers that tsv_path is the path to a multi-cohort TSV. - diagnoses : List[str] (optional, default=()) - List of diagnoses to load if tsv_path is a split_directory. - Default uses the same as in training step. - use_labels : bool (optional, default=True) - If True, the labels must exist in test meta-data and metrics are computed. - batch_size : int (optional, default=None) - If given, sets the value of batch_size, else use the same as in training step. - n_proc : int (optional, default=None) - If given, sets the value of num_workers, else use the same as in training step. - gpu : bool (optional, default=None) - If given, a new value for the device of the model will be computed. - amp : bool (optional, default=False) - If enabled, uses Automatic Mixed Precision (requires GPU usage). - overwrite : bool (optional, default=False) - If True erase the occurrences of data_group. - label : str (optional, default=None) - Target label used for training (if network_task in [`regression`, `classification`]). - label_code : Optional[Dict[str, int]] (optional, default="default") - dictionary linking the target values to a node number. - save_tensor : bool (optional, default=False) - If true, save the tensor predicted for reconstruction task - save_nifti : bool (optional, default=False) - If true, save the nifti associated to the prediction for reconstruction task. - save_latent_tensor : bool (optional, default=False) - If true, save the tensor from the latent space for reconstruction task. - skip_leak_check : bool (optional, default=False) - If true, skip the leak check (not recommended). + label_code : + Examples -------- >>> _input_ @@ -189,8 +148,6 @@ def predict( self.maps_manager, self._config.data_group, split, - self._config.selection_metrics, - self._config.use_labels, self._config.skip_leak_check, ) @@ -437,8 +394,8 @@ def _predict_single( data_test, self._config.data_group, split, - self._config.selection_metrics, - gpu=self._config.gpu, + # self._config.selection_metrics, + # gpu=self._config.gpu, ) if self._config.save_nifti: self._compute_output_nifti( diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 0bcb2e3e4..e2d4076e8 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -226,8 +226,8 @@ def train( n_classes=self.maps_manager.n_classes, network_task=self.maps_manager.network_task, amp=self.maps_manager.std_amp, - use_labels=use_labels, - report_ci=report_ci, + use_labels=self.maps_manager.use_labels, + report_ci=self.maps_manager.report_ci, selection_metrics=self.config.validation.selection_metrics, )