Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Oct 1, 2024
1 parent cf5458d commit d34f4ba
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 50 deletions.
53 changes: 5 additions & 48 deletions clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class PredictManager:
def __init__(self, _config: Union[PredictConfig, InterpretConfig]) -> None:
self.maps_manager = MapsManager(_config.maps_dir)
self._config = _config
self.validator = Validator()
self.validator = Validator(**_config.model_dump())

def predict(
self,
Expand All @@ -48,49 +48,8 @@ def predict(
"""Performs the prediction task on a subset of caps_directory defined in a TSV file.
Parameters
----------
data_group : str
name of the data group tested.
caps_directory : Path (optional, default=None)
path to the CAPS folder. For more information please refer to
[clinica documentation](https://aramislab.paris.inria.fr/clinica/docs/public/latest/CAPS/Introduction/).
Default will load the value of an existing data group
tsv_path : Path (optional, default=None)
path to a TSV file containing the list of participants and sessions to test.
Default will load the DataFrame of an existing data group
split_list : List[int] (optional, default=None)
list of splits to test. Default perform prediction on all splits available.
selection_metrics : List[str] (optional, default=None)
list of selection metrics to test.
Default performs the prediction on all selection metrics available.
multi_cohort : bool (optional, default=False)
If True considers that tsv_path is the path to a multi-cohort TSV.
diagnoses : List[str] (optional, default=())
List of diagnoses to load if tsv_path is a split_directory.
Default uses the same as in training step.
use_labels : bool (optional, default=True)
If True, the labels must exist in test meta-data and metrics are computed.
batch_size : int (optional, default=None)
If given, sets the value of batch_size, else use the same as in training step.
n_proc : int (optional, default=None)
If given, sets the value of num_workers, else use the same as in training step.
gpu : bool (optional, default=None)
If given, a new value for the device of the model will be computed.
amp : bool (optional, default=False)
If enabled, uses Automatic Mixed Precision (requires GPU usage).
overwrite : bool (optional, default=False)
If True erase the occurrences of data_group.
label : str (optional, default=None)
Target label used for training (if network_task in [`regression`, `classification`]).
label_code : Optional[Dict[str, int]] (optional, default="default")
dictionary linking the target values to a node number.
save_tensor : bool (optional, default=False)
If true, save the tensor predicted for reconstruction task
save_nifti : bool (optional, default=False)
If true, save the nifti associated to the prediction for reconstruction task.
save_latent_tensor : bool (optional, default=False)
If true, save the tensor from the latent space for reconstruction task.
skip_leak_check : bool (optional, default=False)
If true, skip the leak check (not recommended).
label_code :
Examples
--------
>>> _input_
Expand Down Expand Up @@ -189,8 +148,6 @@ def predict(
self.maps_manager,
self._config.data_group,
split,
self._config.selection_metrics,
self._config.use_labels,
self._config.skip_leak_check,
)

Expand Down Expand Up @@ -437,8 +394,8 @@ def _predict_single(
data_test,
self._config.data_group,
split,
self._config.selection_metrics,
gpu=self._config.gpu,
# self._config.selection_metrics,
# gpu=self._config.gpu,
)
if self._config.save_nifti:
self._compute_output_nifti(
Expand Down
4 changes: 2 additions & 2 deletions clinicadl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def train(
n_classes=self.maps_manager.n_classes,
network_task=self.maps_manager.network_task,
amp=self.maps_manager.std_amp,
use_labels=use_labels,
report_ci=report_ci,
use_labels=self.maps_manager.use_labels,
report_ci=self.maps_manager.report_ci,
selection_metrics=self.config.validation.selection_metrics,
)

Expand Down

0 comments on commit d34f4ba

Please sign in to comment.