From 64a895858b3d505dfe8f5715ce3230a929cca927 Mon Sep 17 00:00:00 2001 From: camillebrianceau <57992134+camillebrianceau@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:34:22 +0200 Subject: [PATCH] Some cleaning with utils files (#633) * clean Co-authored-by: thibaultdvx <154365476+thibaultdvx@users.noreply.github.com> --- clinicadl/caps_dataset/caps_dataset_config.py | 6 +- clinicadl/caps_dataset/caps_dataset_utils.py | 143 ++++++++++++- clinicadl/caps_dataset/data.py | 9 +- clinicadl/caps_dataset/data_config.py | 6 +- clinicadl/caps_dataset/extraction/config.py | 2 +- clinicadl/caps_dataset/preprocessing/utils.py | 147 ++++++++++++++ .../modules_options/computational.py | 6 +- .../modules_options/cross_validation.py | 2 +- .../modules_options/early_stopping.py | 2 +- .../modules_options/maps_manager.py | 2 +- .../commandline/modules_options/validation.py | 2 +- .../pipelines/generate/artifacts/cli.py | 13 +- .../pipelines/generate/hypometabolic/cli.py | 8 +- .../pipelines/generate/random/cli.py | 7 +- .../pipelines/generate/shepplogan/cli.py | 7 +- .../pipelines/generate/trivial/cli.py | 11 +- .../prepare_data/prepare_data_cli.py | 2 - .../pipelines/quality_check/pet_linear/cli.py | 1 - .../pipelines/quality_check/t1_linear/cli.py | 8 +- .../pipelines/train/classification/cli.py | 2 +- .../pipelines/train/list_models/cli.py | 2 +- .../pipelines/train/reconstruction/cli.py | 2 +- .../pipelines/train/regression/cli.py | 2 +- .../pipelines/transfer_learning/options.py | 2 +- clinicadl/config/config/ssda.py | 2 +- clinicadl/generate/generate_utils.py | 41 +--- clinicadl/hugging_face/hugging_face.py | 4 +- clinicadl/interpret/config.py | 8 +- .../pipelines => maps_manager}/__init__.py | 0 .../config.py} | 0 .../{utils => }/maps_manager/maps_manager.py | 15 +- clinicadl/predict/config.py | 9 +- clinicadl/predict/predict_manager.py | 6 +- clinicadl/prepare_data/prepare_data.py | 8 +- clinicadl/prepare_data/prepare_data_utils.py | 2 +- .../quality_check/pet_linear/quality_check.py | 5 +- .../quality_check/t1_linear/quality_check.py | 4 +- clinicadl/quality_check/t1_linear/utils.py | 5 +- clinicadl/quality_check/t1_volume/utils.py | 2 +- .../random_search/random_search_utils.py | 4 +- clinicadl/train/__init__.py | 0 clinicadl/trainer/config/classification.py | 2 +- clinicadl/trainer/config/reconstruction.py | 2 +- clinicadl/trainer/config/regression.py | 2 +- clinicadl/trainer/config/train.py | 14 +- clinicadl/{train => trainer}/tasks_utils.py | 0 clinicadl/trainer/trainer.py | 19 +- .../config => trainer}/transfer_learning.py | 0 clinicadl/tsvtools/get_labels/get_labels.py | 2 +- clinicadl/tsvtools/getlabels/getlabels.py | 2 +- clinicadl/tsvtools/kfold/kfold.py | 2 +- clinicadl/tsvtools/split/split.py | 2 +- .../{maps_manager => }/cluster/__init__.py | 0 .../cluster/api/__init__.py | 0 .../cluster/api/auto_master_addr_port.py | 0 .../{maps_manager => }/cluster/api/base.py | 0 .../{maps_manager => }/cluster/api/default.py | 0 .../{maps_manager => }/cluster/api/slurm.py | 0 .../cluster/api/torchelastic.py | 0 .../{maps_manager => }/cluster/config.py | 0 .../{maps_manager => }/cluster/interface.py | 0 .../utils/{maps_manager => }/cluster/utils.py | 0 .../computational}/computational.py | 0 .../{maps_manager => computational}/ddp.py | 2 +- .../early_stopping/config.py} | 0 .../{ => early_stopping}/early_stopping.py | 0 clinicadl/utils/iotools/__init__.py | 42 ++++ .../utils/{ => iotools}/clinica_utils.py | 148 -------------- .../iotools}/data_utils.py | 1 - .../{maps_manager => iotools}/iotools.py | 0 clinicadl/utils/iotools/maps_manager_utils.py | 75 +++++++ clinicadl/utils/{ => iotools}/read_utils.py | 6 +- .../utils.py => utils/iotools/train_utils.py} | 7 +- .../iotools}/trainer_utils.py | 0 .../extraction => utils/iotools}/utils.py | 0 .../utils/{maps_manager => }/logwriter.py | 0 clinicadl/utils/maps_manager/__init__.py | 7 - .../maps_manager/cluster/profiler/__init__.py | 3 - .../cluster/profiler/patch_kineto.py | 86 -------- .../utils/maps_manager/maps_manager_utils.py | 190 ------------------ clinicadl/utils/meta_maps/getter.py | 2 +- clinicadl/utils/task_manager/task_manager.py | 2 +- .../config => validation}/cross_validation.py | 2 +- .../split_manager/__init__.py | 0 .../split_manager/kfold.py | 2 +- .../split_manager/single_split.py | 2 +- .../split_manager/split_manager.py | 2 +- .../config => validation}/validation.py | 0 tests/test_resume.py | 2 +- tests/unittests/train/test_utils.py | 6 +- .../train/trainer/test_training_config.py | 4 +- .../unittests/utils/caps_dataset/test_data.py | 4 +- tests/unittests/utils/test_clinica_utils.py | 3 +- 93 files changed, 537 insertions(+), 617 deletions(-) create mode 100644 clinicadl/caps_dataset/preprocessing/utils.py rename clinicadl/{config/config/pipelines => maps_manager}/__init__.py (100%) rename clinicadl/{config/config/maps_manager.py => maps_manager/config.py} (100%) rename clinicadl/{utils => }/maps_manager/maps_manager.py (98%) delete mode 100755 clinicadl/train/__init__.py rename clinicadl/{train => trainer}/tasks_utils.py (100%) rename clinicadl/{config/config => trainer}/transfer_learning.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/__init__.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/__init__.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/auto_master_addr_port.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/base.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/default.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/slurm.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/api/torchelastic.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/config.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/interface.py (100%) rename clinicadl/utils/{maps_manager => }/cluster/utils.py (100%) rename clinicadl/{config/config => utils/computational}/computational.py (100%) rename clinicadl/utils/{maps_manager => computational}/ddp.py (99%) rename clinicadl/{config/config/early_stopping.py => utils/early_stopping/config.py} (100%) rename clinicadl/utils/{ => early_stopping}/early_stopping.py (100%) create mode 100644 clinicadl/utils/iotools/__init__.py rename clinicadl/utils/{ => iotools}/clinica_utils.py (86%) rename clinicadl/{caps_dataset => utils/iotools}/data_utils.py (99%) rename clinicadl/utils/{maps_manager => iotools}/iotools.py (100%) create mode 100644 clinicadl/utils/iotools/maps_manager_utils.py rename clinicadl/utils/{ => iotools}/read_utils.py (98%) rename clinicadl/{train/utils.py => utils/iotools/train_utils.py} (97%) rename clinicadl/{trainer => utils/iotools}/trainer_utils.py (100%) rename clinicadl/{caps_dataset/extraction => utils/iotools}/utils.py (100%) rename clinicadl/utils/{maps_manager => }/logwriter.py (100%) delete mode 100644 clinicadl/utils/maps_manager/__init__.py delete mode 100644 clinicadl/utils/maps_manager/cluster/profiler/__init__.py delete mode 100644 clinicadl/utils/maps_manager/cluster/profiler/patch_kineto.py delete mode 100644 clinicadl/utils/maps_manager/maps_manager_utils.py rename clinicadl/{config/config => validation}/cross_validation.py (94%) rename clinicadl/{utils => validation}/split_manager/__init__.py (100%) rename clinicadl/{utils => validation}/split_manager/kfold.py (94%) rename clinicadl/{utils => validation}/split_manager/single_split.py (92%) rename clinicadl/{utils => validation}/split_manager/split_manager.py (99%) rename clinicadl/{config/config => validation}/validation.py (100%) diff --git a/clinicadl/caps_dataset/caps_dataset_config.py b/clinicadl/caps_dataset/caps_dataset_config.py index 3eec87c54..b7086944c 100644 --- a/clinicadl/caps_dataset/caps_dataset_config.py +++ b/clinicadl/caps_dataset/caps_dataset_config.py @@ -14,15 +14,15 @@ PreprocessingConfig, T1PreprocessingConfig, ) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.clinica_utils import ( - FileType, +from clinicadl.caps_dataset.preprocessing.utils import ( bids_nii, dwi_dti, linear_nii, pet_linear_nii, ) +from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.enum import ExtractionMethod, Preprocessing +from clinicadl.utils.iotools.clinica_utils import FileType def get_extraction(extract_method: ExtractionMethod): diff --git a/clinicadl/caps_dataset/caps_dataset_utils.py b/clinicadl/caps_dataset/caps_dataset_utils.py index d5244332c..b87c6ed22 100644 --- a/clinicadl/caps_dataset/caps_dataset_utils.py +++ b/clinicadl/caps_dataset/caps_dataset_utils.py @@ -1,5 +1,6 @@ +import json from pathlib import Path -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.caps_dataset.preprocessing.config import ( @@ -9,14 +10,15 @@ PETPreprocessingConfig, T1PreprocessingConfig, ) -from clinicadl.utils.clinica_utils import ( - FileType, +from clinicadl.caps_dataset.preprocessing.utils import ( bids_nii, dwi_dti, linear_nii, pet_linear_nii, ) from clinicadl.utils.enum import Preprocessing +from clinicadl.utils.exceptions import ClinicaDLArgumentError +from clinicadl.utils.iotools.clinica_utils import FileType def compute_folder_and_file_type( @@ -54,3 +56,138 @@ def compute_folder_and_file_type( description="Custom suffix", ) return mod_subfolder, file_type + + +def find_file_type(config: CapsDatasetConfig) -> FileType: + if isinstance(config.preprocessing, T1PreprocessingConfig): + file_type = linear_nii(config.preprocessing) + elif isinstance(config.preprocessing, PETPreprocessingConfig): + if ( + config.preprocessing.tracer is None + or config.preprocessing.suvr_reference_region is None + ): + raise ClinicaDLArgumentError( + "`tracer` and `suvr_reference_region` must be defined " + "when using `pet-linear` preprocessing." + ) + file_type = pet_linear_nii(config.preprocessing) + else: + raise NotImplementedError( + f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" + ) + + return file_type + + +def read_json(json_path: Path) -> Dict[str, Any]: + """ + Ensures retro-compatibility between the different versions of ClinicaDL. + + Parameters + ---------- + json_path: Path + path to the JSON file summing the parameters of a MAPS. + + Returns + ------- + A dictionary of training parameters. + """ + from clinicadl.utils.iotools.utils import path_decoder + + with json_path.open(mode="r") as f: + parameters = json.load(f, object_hook=path_decoder) + # Types of retro-compatibility + # Change arg name: ex network --> model + # Change arg value: ex for preprocessing: mni --> t1-extensive + # New arg with default hard-coded value --> discarded_slice --> 20 + retro_change_name = { + "model": "architecture", + "multi": "multi_network", + "minmaxnormalization": "normalize", + "num_workers": "n_proc", + "mode": "extract_method", + } + + retro_add = { + "optimizer": "Adam", + "loss": None, + } + + for old_name, new_name in retro_change_name.items(): + if old_name in parameters: + parameters[new_name] = parameters[old_name] + del parameters[old_name] + + for name, value in retro_add.items(): + if name not in parameters: + parameters[name] = value + + if "extract_method" in parameters: + parameters["mode"] = parameters["extract_method"] + # Value changes + if "use_cpu" in parameters: + parameters["gpu"] = not parameters["use_cpu"] + del parameters["use_cpu"] + if "nondeterministic" in parameters: + parameters["deterministic"] = not parameters["nondeterministic"] + del parameters["nondeterministic"] + + # Build preprocessing_dict + if "preprocessing_dict" not in parameters: + parameters["preprocessing_dict"] = {"mode": parameters["mode"]} + preprocessing_options = [ + "preprocessing", + "use_uncropped_image", + "prepare_dl", + "custom_suffix", + "tracer", + "suvr_reference_region", + "patch_size", + "stride_size", + "slice_direction", + "slice_mode", + "discarded_slices", + "roi_list", + "uncropped_roi", + "roi_custom_suffix", + "roi_custom_template", + "roi_custom_mask_pattern", + ] + for preprocessing_var in preprocessing_options: + if preprocessing_var in parameters: + parameters["preprocessing_dict"][preprocessing_var] = parameters[ + preprocessing_var + ] + del parameters[preprocessing_var] + + # Add missing parameters in previous version of extract + if "use_uncropped_image" not in parameters["preprocessing_dict"]: + parameters["preprocessing_dict"]["use_uncropped_image"] = False + + if ( + "prepare_dl" not in parameters["preprocessing_dict"] + and parameters["mode"] != "image" + ): + parameters["preprocessing_dict"]["prepare_dl"] = False + + if ( + parameters["mode"] == "slice" + and "slice_mode" not in parameters["preprocessing_dict"] + ): + parameters["preprocessing_dict"]["slice_mode"] = "rgb" + + if "preprocessing" not in parameters: + parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] + + from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig + + config = CapsDatasetConfig.from_preprocessing_and_extraction_method( + extraction=parameters["mode"], + preprocessing_type=parameters["preprocessing"], + **parameters, + ) + if "file_type" not in parameters["preprocessing_dict"]: + _, file_type = compute_folder_and_file_type(config) + parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() + + return parameters diff --git a/clinicadl/caps_dataset/data.py b/clinicadl/caps_dataset/data.py index 884d184a1..48d1a5480 100644 --- a/clinicadl/caps_dataset/data.py +++ b/clinicadl/caps_dataset/data.py @@ -3,7 +3,7 @@ import abc from logging import getLogger from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -11,7 +11,6 @@ from torch.utils.data import Dataset from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type from clinicadl.caps_dataset.extraction.config import ( ExtractionImageConfig, ExtractionPatchConfig, @@ -30,7 +29,6 @@ ) from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.enum import ( - ExtractionMethod, Pattern, Preprocessing, SliceDirection, @@ -38,7 +36,6 @@ Template, ) from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, ClinicaDLCAPSError, ClinicaDLTSVError, ) @@ -133,7 +130,7 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: Returns: image_path: path to the tensor containing the whole image. """ - from clinicadl.utils.clinica_utils import clinicadl_file_reader + from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader # Try to find .nii.gz file try: @@ -221,7 +218,7 @@ def _get_full_image(self) -> torch.Tensor: """ import nibabel as nib - from clinicadl.utils.clinica_utils import clinicadl_file_reader + from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader participant_id = self.df.loc[0, "participant_id"] session_id = self.df.loc[0, "session_id"] diff --git a/clinicadl/caps_dataset/data_config.py b/clinicadl/caps_dataset/data_config.py index 155146380..5fdeb568e 100644 --- a/clinicadl/caps_dataset/data_config.py +++ b/clinicadl/caps_dataset/data_config.py @@ -5,13 +5,13 @@ import pandas as pd from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv, load_data_test -from clinicadl.caps_dataset.extraction.utils import read_preprocessing from clinicadl.utils.enum import Mode from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLTSVError, ) +from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test +from clinicadl.utils.iotools.utils import read_preprocessing logger = getLogger("clinicadl.data_config") @@ -85,7 +85,7 @@ def check_data_tsv(cls, v) -> Path: @computed_field @property def caps_dict(self) -> Dict[str, Path]: - from clinicadl.utils.clinica_utils import check_caps_folder + from clinicadl.utils.iotools.clinica_utils import check_caps_folder if self.multi_cohort: if self.caps_directory.suffix != ".tsv": diff --git a/clinicadl/caps_dataset/extraction/config.py b/clinicadl/caps_dataset/extraction/config.py index 5a88ca4fd..f3619590f 100644 --- a/clinicadl/caps_dataset/extraction/config.py +++ b/clinicadl/caps_dataset/extraction/config.py @@ -5,12 +5,12 @@ from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt -from clinicadl.utils.clinica_utils import FileType from clinicadl.utils.enum import ( ExtractionMethod, SliceDirection, SliceMode, ) +from clinicadl.utils.iotools.clinica_utils import FileType logger = getLogger("clinicadl.preprocessing_config") diff --git a/clinicadl/caps_dataset/preprocessing/utils.py b/clinicadl/caps_dataset/preprocessing/utils.py new file mode 100644 index 000000000..0aa93004d --- /dev/null +++ b/clinicadl/caps_dataset/preprocessing/utils.py @@ -0,0 +1,147 @@ +from typing import Optional + +from clinicadl.caps_dataset.preprocessing import config as preprocessing_config +from clinicadl.utils.enum import ( + LinearModality, + Preprocessing, + Tracer, +) +from clinicadl.utils.exceptions import ClinicaDLArgumentError +from clinicadl.utils.iotools.clinica_utils import FileType + + +def bids_nii( + config: preprocessing_config.PreprocessingConfig, + reconstruction: Optional[str] = None, +) -> FileType: + """Return the query dict required to capture PET scans. + + Parameters + ---------- + tracer : Tracer, optional + If specified, the query will only match PET scans acquired + with the requested tracer. + If None, the query will match all PET sans independently of + the tracer used. + + reconstruction : ReconstructionMethod, optional + If specified, the query will only match PET scans reconstructed + with the specified method. + If None, the query will match all PET scans independently of the + reconstruction method used. + + Returns + ------- + dict : + The query dictionary to get PET scans. + """ + + if config.preprocessing not in Preprocessing: + raise ClinicaDLArgumentError( + f"ClinicaDL is Unable to read this modality ({config.preprocessing}) of images, please chose one from this list: {list[Preprocessing]}" + ) + + if isinstance(config, preprocessing_config.PETPreprocessingConfig): + trc = "" if config.tracer is None else f"_trc-{Tracer(config.tracer).value}" + rec = "" if reconstruction is None else f"_rec-{reconstruction}" + description = "PET data" + + if config.tracer: + description += f" with {config.tracer.value} tracer" + if reconstruction: + description += f" and reconstruction method {reconstruction}" + + file_type = FileType( + pattern=f"pet/*{trc}{rec}_pet.nii*", description=description + ) + return file_type + + elif isinstance(config, preprocessing_config.T1PreprocessingConfig): + return FileType(pattern="anat/sub-*_ses-*_T1w.nii*", description="T1w MRI") + + elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): + return FileType(pattern="sub-*_ses-*_flair.nii*", description="FLAIR T2w MRI") + + elif isinstance(config, preprocessing_config.DTIPreprocessingConfig): + return FileType(pattern="dwi/sub-*_ses-*_dwi.nii*", description="DWI NIfTI") + + else: + raise ClinicaDLArgumentError("Invalid preprocessing") + + +def linear_nii( + config: preprocessing_config.PreprocessingConfig, +) -> FileType: + if isinstance(config, preprocessing_config.T1PreprocessingConfig): + needed_pipeline = Preprocessing.T1_LINEAR + modality = LinearModality.T1W + elif isinstance(config, preprocessing_config.T2PreprocessingConfig): + needed_pipeline = Preprocessing.T2_LINEAR + modality = LinearModality.T2W + elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): + needed_pipeline = Preprocessing.FLAIR_LINEAR + modality = LinearModality.FLAIR + else: + raise ClinicaDLArgumentError("Invalid configuration") + + if config.use_uncropped_image: + desc_crop = "" + else: + desc_crop = "_desc-Crop" + + file_type = FileType( + pattern=f"*space-MNI152NLin2009cSym{desc_crop}_res-1x1x1_{modality.value}.nii.gz", + description=f"{modality.value} Image registered in MNI152NLin2009cSym space using {needed_pipeline.value} pipeline " + + ( + "" + if config.use_uncropped_image + else "and cropped (matrix size 169×208×179, 1 mm isotropic voxels)" + ), + needed_pipeline=needed_pipeline, + ) + return file_type + + +def dwi_dti(config: preprocessing_config.DTIPreprocessingConfig) -> FileType: + """Return the query dict required to capture DWI DTI images. + + Parameters + ---------- + config: DTIPreprocessingConfig + + Returns + ------- + FileType : + """ + if isinstance(config, preprocessing_config.DTIPreprocessingConfig): + measure = config.dti_measure + space = config.dti_space + else: + raise ClinicaDLArgumentError( + f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.DTIPreprocessingConfig}" + ) + + return FileType( + pattern=f"dwi/dti_based_processing/*/*_space-{space}_{measure.value}.nii.gz", + description=f"DTI-based {measure.value} in space {space}.", + needed_pipeline="dwi_dti", + ) + + +def pet_linear_nii(config: preprocessing_config.PETPreprocessingConfig) -> FileType: + if not isinstance(config, preprocessing_config.PETPreprocessingConfig): + raise ClinicaDLArgumentError( + f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.PETPreprocessingConfig}" + ) + + if config.use_uncropped_image: + description = "" + else: + description = "_desc-Crop" + + file_type = FileType( + pattern=f"pet_linear/*_trc-{config.tracer.value}_space-MNI152NLin2009cSym{description}_res-1x1x1_suvr-{config.suvr_reference_region.value}_pet.nii.gz", + description="", + needed_pipeline="pet-linear", + ) + return file_type diff --git a/clinicadl/commandline/modules_options/computational.py b/clinicadl/commandline/modules_options/computational.py index 221d25f8a..5bc05e158 100644 --- a/clinicadl/commandline/modules_options/computational.py +++ b/clinicadl/commandline/modules_options/computational.py @@ -1,7 +1,7 @@ import click -from clinicadl.config.config.computational import ComputationalConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default +from clinicadl.utils.computational.computational import ComputationalConfig # Computational amp = click.option( @@ -19,8 +19,8 @@ "this flag is already set to FSDP to that the zero flag is never actually removed.", ) gpu = click.option( - "--no-gpu", - is_flag=True, + "--gpu/--no-gpu", + default=get_default("gpu", ComputationalConfig), help="Use GPU by default. Please specify `--no-gpu` to force using CPU.", show_default=True, ) diff --git a/clinicadl/commandline/modules_options/cross_validation.py b/clinicadl/commandline/modules_options/cross_validation.py index 734c86dbb..c1c745ce3 100644 --- a/clinicadl/commandline/modules_options/cross_validation.py +++ b/clinicadl/commandline/modules_options/cross_validation.py @@ -1,8 +1,8 @@ import click -from clinicadl.config.config.cross_validation import CrossValidationConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.validation.cross_validation import CrossValidationConfig # Cross Validation n_splits = click.option( diff --git a/clinicadl/commandline/modules_options/early_stopping.py b/clinicadl/commandline/modules_options/early_stopping.py index 6385c3864..a41ab2a48 100644 --- a/clinicadl/commandline/modules_options/early_stopping.py +++ b/clinicadl/commandline/modules_options/early_stopping.py @@ -1,8 +1,8 @@ import click -from clinicadl.config.config.early_stopping import EarlyStoppingConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.utils.early_stopping.config import EarlyStoppingConfig # Early Stopping patience = click.option( diff --git a/clinicadl/commandline/modules_options/maps_manager.py b/clinicadl/commandline/modules_options/maps_manager.py index a9eeaea89..f973f441a 100644 --- a/clinicadl/commandline/modules_options/maps_manager.py +++ b/clinicadl/commandline/modules_options/maps_manager.py @@ -1,7 +1,7 @@ import click -from clinicadl.config.config.maps_manager import MapsManagerConfig from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.maps_manager.config import MapsManagerConfig maps_dir = click.argument("maps_dir", type=get_type("maps_dir", MapsManagerConfig)) data_group = click.option("data_group", type=get_type("data_group", MapsManagerConfig)) diff --git a/clinicadl/commandline/modules_options/validation.py b/clinicadl/commandline/modules_options/validation.py index 9f26d3311..4e2e973e3 100644 --- a/clinicadl/commandline/modules_options/validation.py +++ b/clinicadl/commandline/modules_options/validation.py @@ -1,8 +1,8 @@ import click -from clinicadl.config.config.validation import ValidationConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.validation.validation import ValidationConfig # Validation valid_longitudinal = click.option( diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py index 1c89d4979..b4a98b40a 100644 --- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py +++ b/clinicadl/commandline/pipelines/generate/artifacts/cli.py @@ -6,26 +6,24 @@ import torchio as tio from joblib import Parallel, delayed -from clinicadl.caps_dataset.caps_dataset_config import ( - CapsDatasetConfig, -) +from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, dataloader, - extraction, preprocessing, ) from clinicadl.commandline.pipelines.generate.artifacts import options as artifacts from clinicadl.generate.generate_config import GenerateArtifactsConfig from clinicadl.generate.generate_utils import ( - find_file_type, load_and_check_tsv, write_missing_mods, ) -from clinicadl.utils.clinica_utils import clinicadl_file_reader from clinicadl.utils.enum import ExtractionMethod -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader +from clinicadl.utils.iotools.iotools import commandline_to_json +from clinicadl.utils.iotools.read_utils import get_info_from_filename logger = getLogger("clinicadl.generate.artifacts") @@ -103,7 +101,6 @@ def create_artifacts_image(data_idx: int) -> pd.DataFrame: file_type.model_dump(), )[0][0] ) - from clinicadl.utils.read_utils import get_info_from_filename ( subject_name, diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py index b374d58a5..cb68269ca 100644 --- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py +++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py @@ -8,6 +8,7 @@ from nilearn.image import resample_to_img from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import data, dataloader, preprocessing from clinicadl.commandline.pipelines.generate.hypometabolic import ( @@ -15,19 +16,18 @@ ) from clinicadl.generate.generate_config import GenerateHypometabolicConfig from clinicadl.generate.generate_utils import ( - find_file_type, load_and_check_tsv, mask_processing, write_missing_mods, ) from clinicadl.tsvtools.tsvtools_utils import extract_baseline -from clinicadl.utils.clinica_utils import clinicadl_file_reader from clinicadl.utils.enum import ( ExtractionMethod, Preprocessing, ) -from clinicadl.utils.maps_manager.iotools import commandline_to_json -from clinicadl.utils.read_utils import get_mask_path +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader +from clinicadl.utils.iotools.iotools import commandline_to_json +from clinicadl.utils.iotools.read_utils import get_mask_path logger = getLogger("clinicadl.generate.hypometabolic") diff --git a/clinicadl/commandline/pipelines/generate/random/cli.py b/clinicadl/commandline/pipelines/generate/random/cli.py index b4c33a503..cf8e8d9e8 100644 --- a/clinicadl/commandline/pipelines/generate/random/cli.py +++ b/clinicadl/commandline/pipelines/generate/random/cli.py @@ -8,24 +8,23 @@ from joblib import Parallel, delayed from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, dataloader, - extraction, preprocessing, ) from clinicadl.commandline.pipelines.generate.random import options as random from clinicadl.generate.generate_config import GenerateRandomConfig from clinicadl.generate.generate_utils import ( - find_file_type, load_and_check_tsv, write_missing_mods, ) from clinicadl.tsvtools.tsvtools_utils import extract_baseline -from clinicadl.utils.clinica_utils import clinicadl_file_reader from clinicadl.utils.enum import ExtractionMethod -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader +from clinicadl.utils.iotools.iotools import commandline_to_json logger = getLogger("clinicadl.generate.random") diff --git a/clinicadl/commandline/pipelines/generate/shepplogan/cli.py b/clinicadl/commandline/pipelines/generate/shepplogan/cli.py index 42595d35a..e9d20d9dc 100644 --- a/clinicadl/commandline/pipelines/generate/shepplogan/cli.py +++ b/clinicadl/commandline/pipelines/generate/shepplogan/cli.py @@ -1,5 +1,4 @@ from logging import getLogger -from pathlib import Path import click import numpy as np @@ -7,7 +6,6 @@ import torch from joblib import Parallel, delayed -from clinicadl.caps_dataset.extraction.utils import write_preprocessing from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import data, dataloader from clinicadl.commandline.pipelines.generate.shepplogan import options as shepplogan @@ -16,8 +14,9 @@ generate_shepplogan_phantom, write_missing_mods, ) -from clinicadl.utils.clinica_utils import FileType -from clinicadl.utils.maps_manager.iotools import check_and_clean, commandline_to_json +from clinicadl.utils.iotools.clinica_utils import FileType +from clinicadl.utils.iotools.iotools import check_and_clean, commandline_to_json +from clinicadl.utils.iotools.utils import write_preprocessing logger = getLogger("clinicadl.generate.shepplogan") diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py index 580de1626..b48651811 100644 --- a/clinicadl/commandline/pipelines/generate/trivial/cli.py +++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py @@ -7,26 +7,25 @@ from joblib import Parallel, delayed from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig +from clinicadl.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( data, dataloader, - extraction, preprocessing, ) from clinicadl.commandline.pipelines.generate.trivial import options as trivial from clinicadl.generate.generate_config import GenerateTrivialConfig from clinicadl.generate.generate_utils import ( - find_file_type, im_loss_roi_gaussian_distribution, load_and_check_tsv, write_missing_mods, ) from clinicadl.tsvtools.tsvtools_utils import extract_baseline -from clinicadl.utils.clinica_utils import clinicadl_file_reader from clinicadl.utils.enum import ExtractionMethod -from clinicadl.utils.maps_manager.iotools import commandline_to_json -from clinicadl.utils.read_utils import get_mask_path +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader +from clinicadl.utils.iotools.iotools import commandline_to_json +from clinicadl.utils.iotools.read_utils import get_mask_path logger = getLogger("clinicadl.generate.trivial") @@ -100,7 +99,7 @@ def create_trivial_image(subject_id: int) -> pd.DataFrame: )[0][0] ) - from clinicadl.utils.read_utils import get_info_from_filename + from clinicadl.utils.iotools.read_utils import get_info_from_filename _, _, filename_pattern, file_suffix = get_info_from_filename(image_path) diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py index 35d6db73c..c9630c507 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py @@ -1,5 +1,3 @@ -from pathlib import Path - import click from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig diff --git a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py index f6104aebb..455bb5299 100644 --- a/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py +++ b/clinicadl/commandline/pipelines/quality_check/pet_linear/cli.py @@ -4,7 +4,6 @@ from clinicadl.commandline.modules_options import ( data, dataloader, - extraction, preprocessing, ) from clinicadl.utils.enum import ExtractionMethod, Preprocessing diff --git a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py index eff8900f8..f73971a63 100755 --- a/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py +++ b/clinicadl/commandline/pipelines/quality_check/t1_linear/cli.py @@ -1,11 +1,9 @@ -from pathlib import Path - import click from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import computational, data, dataloader -from clinicadl.config.config.computational import ComputationalConfig +from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import ExtractionMethod, Preprocessing @@ -44,7 +42,7 @@ def cli( threshold, batch_size, n_proc, - no_gpu, + gpu, amp, network, use_tensor, @@ -60,7 +58,7 @@ def cli( quality_check as linear_qc, ) - computational_config = ComputationalConfig(amp=amp, gpu=not no_gpu) + computational_config = ComputationalConfig(amp=amp, gpu=gpu) config = CapsDatasetConfig.from_preprocessing_and_extraction_method( caps_directory=caps_directory, extraction=ExtractionMethod.IMAGE, diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py index 409c89c95..d552c318b 100644 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ b/clinicadl/commandline/pipelines/train/classification/cli.py @@ -23,10 +23,10 @@ from clinicadl.commandline.pipelines.transfer_learning import ( options as transfer_learning, ) -from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.trainer.config.classification import ClassificationConfig from clinicadl.trainer.trainer import Trainer from clinicadl.utils.enum import Task +from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @click.command(name="classification", no_args_is_help=True) diff --git a/clinicadl/commandline/pipelines/train/list_models/cli.py b/clinicadl/commandline/pipelines/train/list_models/cli.py index 2bdc1fe7e..95632aefc 100644 --- a/clinicadl/commandline/pipelines/train/list_models/cli.py +++ b/clinicadl/commandline/pipelines/train/list_models/cli.py @@ -28,6 +28,6 @@ def cli( model_layers, ): """Show the list of available models in ClinicaDL.""" - from clinicadl.train.utils import get_model_list + from clinicadl.utils.iotools.train_utils import get_model_list get_model_list(architecture, input_size, model_layers) diff --git a/clinicadl/commandline/pipelines/train/reconstruction/cli.py b/clinicadl/commandline/pipelines/train/reconstruction/cli.py index 33edf1d44..d0a40fa40 100644 --- a/clinicadl/commandline/pipelines/train/reconstruction/cli.py +++ b/clinicadl/commandline/pipelines/train/reconstruction/cli.py @@ -23,10 +23,10 @@ from clinicadl.commandline.pipelines.transfer_learning import ( options as transfer_learning, ) -from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.trainer.config.reconstruction import ReconstructionConfig from clinicadl.trainer.trainer import Trainer from clinicadl.utils.enum import Task +from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @click.command(name="reconstruction", no_args_is_help=True) diff --git a/clinicadl/commandline/pipelines/train/regression/cli.py b/clinicadl/commandline/pipelines/train/regression/cli.py index 5ac1eb545..ffeb218a8 100644 --- a/clinicadl/commandline/pipelines/train/regression/cli.py +++ b/clinicadl/commandline/pipelines/train/regression/cli.py @@ -21,10 +21,10 @@ from clinicadl.commandline.pipelines.transfer_learning import ( options as transfer_learning, ) -from clinicadl.train.utils import merge_cli_and_config_file_options from clinicadl.trainer.config.regression import RegressionConfig from clinicadl.trainer.trainer import Trainer from clinicadl.utils.enum import Task +from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @click.command(name="regression", no_args_is_help=True) diff --git a/clinicadl/commandline/pipelines/transfer_learning/options.py b/clinicadl/commandline/pipelines/transfer_learning/options.py index 643b87caa..870f3e66b 100644 --- a/clinicadl/commandline/pipelines/transfer_learning/options.py +++ b/clinicadl/commandline/pipelines/transfer_learning/options.py @@ -1,8 +1,8 @@ import click -from clinicadl.config.config.transfer_learning import TransferLearningConfig from clinicadl.config.config_utils import get_default_from_config_class as get_default from clinicadl.config.config_utils import get_type_from_config_class as get_type +from clinicadl.trainer.transfer_learning import TransferLearningConfig nb_unfrozen_layer = click.option( "-nul", diff --git a/clinicadl/config/config/ssda.py b/clinicadl/config/config/ssda.py index 1c0c79eb6..caf52634d 100644 --- a/clinicadl/config/config/ssda.py +++ b/clinicadl/config/config/ssda.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, ConfigDict, computed_field -from clinicadl.caps_dataset.extraction.utils import read_preprocessing +from clinicadl.utils.iotools.utils import read_preprocessing logger = getLogger("clinicadl.ssda_config") diff --git a/clinicadl/generate/generate_utils.py b/clinicadl/generate/generate_utils.py index 3e65b5949..2749e9fae 100755 --- a/clinicadl/generate/generate_utils.py +++ b/clinicadl/generate/generate_utils.py @@ -10,48 +10,11 @@ from scipy.ndimage import gaussian_filter from skimage.draw import ellipse -from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv -from clinicadl.caps_dataset.preprocessing.config import ( - PETPreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.utils.clinica_utils import ( - FileType, - create_subs_sess_list, - linear_nii, - pet_linear_nii, -) -from clinicadl.utils.enum import ( - LinearModality, -) - -# from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, ClinicaDLTSVError, ) - - -def find_file_type(config: CapsDatasetConfig) -> Dict[str, str]: - if isinstance(config.preprocessing, T1PreprocessingConfig): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - if ( - config.preprocessing.tracer is None - or config.preprocessing.suvr_reference_region is None - ): - raise ClinicaDLArgumentError( - "`tracer` and `suvr_reference_region` must be defined " - "when using `pet-linear` preprocessing." - ) - file_type = pet_linear_nii(config.preprocessing) - else: - raise NotImplementedError( - f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" - ) - - return file_type +from clinicadl.utils.iotools.clinica_utils import create_subs_sess_list +from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv def write_missing_mods(output_dir: Path, output_df: pd.DataFrame) -> None: diff --git a/clinicadl/hugging_face/hugging_face.py b/clinicadl/hugging_face/hugging_face.py index 791229895..00f729e35 100644 --- a/clinicadl/hugging_face/hugging_face.py +++ b/clinicadl/hugging_face/hugging_face.py @@ -5,9 +5,9 @@ import toml +from clinicadl.caps_dataset.caps_dataset_utils import read_json from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.maps_manager.maps_manager_utils import ( - read_json, +from clinicadl.utils.iotools.maps_manager_utils import ( remove_unused_tasks, ) diff --git a/clinicadl/interpret/config.py b/clinicadl/interpret/config.py index 1e59f3efc..abbf89b64 100644 --- a/clinicadl/interpret/config.py +++ b/clinicadl/interpret/config.py @@ -6,12 +6,12 @@ from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.computational import ComputationalConfig -from clinicadl.config.config.cross_validation import CrossValidationConfig -from clinicadl.config.config.maps_manager import MapsManagerConfig -from clinicadl.config.config.validation import ValidationConfig from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp +from clinicadl.maps_manager.config import MapsManagerConfig +from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.enum import InterpretationMethod +from clinicadl.validation.cross_validation import CrossValidationConfig +from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.interpret_config") diff --git a/clinicadl/config/config/pipelines/__init__.py b/clinicadl/maps_manager/__init__.py similarity index 100% rename from clinicadl/config/config/pipelines/__init__.py rename to clinicadl/maps_manager/__init__.py diff --git a/clinicadl/config/config/maps_manager.py b/clinicadl/maps_manager/config.py similarity index 100% rename from clinicadl/config/config/maps_manager.py rename to clinicadl/maps_manager/config.py diff --git a/clinicadl/utils/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py similarity index 98% rename from clinicadl/utils/maps_manager/maps_manager.py rename to clinicadl/maps_manager/maps_manager.py index 9888af3f8..d02d3c5a3 100644 --- a/clinicadl/utils/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -10,21 +10,22 @@ import torch.distributed as dist from torch.cuda.amp import autocast +from clinicadl.caps_dataset.caps_dataset_utils import read_json from clinicadl.caps_dataset.data import ( return_dataset, ) -from clinicadl.caps_dataset.extraction.utils import path_encoder from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils import cluster +from clinicadl.utils.computational.ddp import DDP, init_ddp from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLConfigurationError, MAPSError, ) -from clinicadl.utils.maps_manager.ddp import DDP, cluster, init_ddp -from clinicadl.utils.maps_manager.maps_manager_utils import ( +from clinicadl.utils.iotools.maps_manager_utils import ( add_default_values, - read_json, ) +from clinicadl.utils.iotools.utils import path_encoder logger = getLogger("clinicadl.maps_manager") level_list: List[str] = ["warning", "info", "debug"] @@ -566,7 +567,7 @@ def _write_requirements_version(self): def _write_training_data(self): """Writes the TSV file containing the participant and session IDs used for training.""" logger.debug("Writing training data...") - from clinicadl.caps_dataset.data_utils import load_data_test + from clinicadl.utils.iotools.data_utils import load_data_test train_df = load_data_test( self.tsv_path, @@ -919,7 +920,7 @@ def _init_model( return model, current_epoch def _init_split_manager(self, split_list=None, ssda_bool: bool = False): - from clinicadl.utils import split_manager + from clinicadl.validation import split_manager split_class = getattr(split_manager, self.validation) args = list( @@ -941,7 +942,7 @@ def _init_split_manager(self, split_list=None, ssda_bool: bool = False): def _init_split_manager_ssda(self, caps_dir, tsv_dir, split_list=None): # A intégrer directement dans _init_split_manager - from clinicadl.utils import split_manager + from clinicadl.validation import split_manager split_class = getattr(split_manager, self.validation) args = list( diff --git a/clinicadl/predict/config.py b/clinicadl/predict/config.py index 19d14e4f3..9304eefd8 100644 --- a/clinicadl/predict/config.py +++ b/clinicadl/predict/config.py @@ -2,14 +2,13 @@ from clinicadl.caps_dataset.data_config import DataConfig as DataBaseConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.maps_manager import ( +from clinicadl.maps_manager.config import ( MapsManagerConfig as MapsManagerBaseConfig, ) +from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore - -from ..config.config.computational import ComputationalConfig -from ..config.config.cross_validation import CrossValidationConfig -from ..config.config.validation import ValidationConfig +from clinicadl.validation.cross_validation import CrossValidationConfig +from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.predict_config") diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py index fd6084f0c..0bb07f8d8 100644 --- a/clinicadl/predict/predict_manager.py +++ b/clinicadl/predict/predict_manager.py @@ -15,15 +15,15 @@ return_dataset, ) from clinicadl.interpret.config import InterpretConfig +from clinicadl.maps_manager.maps_manager import MapsManager from clinicadl.predict.config import PredictConfig from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils.computational.ddp import DDP, cluster from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLDataLeakageError, MAPSError, ) -from clinicadl.utils.maps_manager.ddp import DDP, cluster -from clinicadl.utils.maps_manager.maps_manager import MapsManager logger = getLogger("clinicadl.predict_manager") level_list: List[str] = ["warning", "info", "debug"] @@ -916,7 +916,7 @@ def get_group_info( df = pd.read_csv(group_path / "data.tsv", sep="\t") json_path = group_path / "maps.json" - from clinicadl.caps_dataset.extraction.utils import path_decoder + from clinicadl.utils.iotools.utils import path_decoder with json_path.open(mode="r") as f: parameters = json.load(f, object_hook=path_decoder) diff --git a/clinicadl/prepare_data/prepare_data.py b/clinicadl/prepare_data/prepare_data.py index 262d178e2..e9b7fc073 100644 --- a/clinicadl/prepare_data/prepare_data.py +++ b/clinicadl/prepare_data/prepare_data.py @@ -14,16 +14,16 @@ ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.caps_dataset.extraction.utils import write_preprocessing -from clinicadl.utils.clinica_utils import ( +from clinicadl.utils.enum import ExtractionMethod, Pattern, Preprocessing, Template +from clinicadl.utils.exceptions import ClinicaDLArgumentError +from clinicadl.utils.iotools.clinica_utils import ( check_caps_folder, clinicadl_file_reader, container_from_filename, determine_caps_or_bids, get_subject_session_list, ) -from clinicadl.utils.enum import ExtractionMethod, Pattern, Preprocessing, Template -from clinicadl.utils.exceptions import ClinicaDLArgumentError +from clinicadl.utils.iotools.utils import write_preprocessing from .prepare_data_utils import check_mask_list diff --git a/clinicadl/prepare_data/prepare_data_utils.py b/clinicadl/prepare_data/prepare_data_utils.py index 73207ec39..0acd2ec25 100644 --- a/clinicadl/prepare_data/prepare_data_utils.py +++ b/clinicadl/prepare_data/prepare_data_utils.py @@ -1,6 +1,6 @@ # coding: utf8 from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch diff --git a/clinicadl/quality_check/pet_linear/quality_check.py b/clinicadl/quality_check/pet_linear/quality_check.py index 1b4e9d2a9..7c355b09c 100644 --- a/clinicadl/quality_check/pet_linear/quality_check.py +++ b/clinicadl/quality_check/pet_linear/quality_check.py @@ -13,14 +13,13 @@ from joblib import Parallel, delayed from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.utils.clinica_utils import ( +from clinicadl.caps_dataset.preprocessing.utils import pet_linear_nii +from clinicadl.utils.iotools.clinica_utils import ( RemoteFileStructure, clinicadl_file_reader, fetch_file, get_subject_session_list, - pet_linear_nii, ) -from clinicadl.utils.enum import SUVRReferenceRegions, Tracer from .utils import get_metric diff --git a/clinicadl/quality_check/t1_linear/quality_check.py b/clinicadl/quality_check/t1_linear/quality_check.py index f0d82f980..f840a4583 100755 --- a/clinicadl/quality_check/t1_linear/quality_check.py +++ b/clinicadl/quality_check/t1_linear/quality_check.py @@ -12,10 +12,10 @@ from torch.utils.data import DataLoader from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.config.config.computational import ComputationalConfig from clinicadl.generate.generate_utils import load_and_check_tsv -from clinicadl.utils.clinica_utils import RemoteFileStructure, fetch_file +from clinicadl.utils.computational.computational import ComputationalConfig from clinicadl.utils.exceptions import ClinicaDLArgumentError +from clinicadl.utils.iotools.clinica_utils import RemoteFileStructure, fetch_file from .models import resnet_darq_qc_18 as darq_r18 from .models import resnet_deep_qc_18 as deep_r18 diff --git a/clinicadl/quality_check/t1_linear/utils.py b/clinicadl/quality_check/t1_linear/utils.py index d368bd2e4..b9f03ba3e 100755 --- a/clinicadl/quality_check/t1_linear/utils.py +++ b/clinicadl/quality_check/t1_linear/utils.py @@ -10,9 +10,10 @@ from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type -from clinicadl.utils.clinica_utils import clinicadl_file_reader, linear_nii -from clinicadl.utils.enum import LinearModality, Preprocessing +from clinicadl.caps_dataset.preprocessing.utils import linear_nii +from clinicadl.utils.enum import Preprocessing from clinicadl.utils.exceptions import ClinicaDLException +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader class QCDataset(Dataset): diff --git a/clinicadl/quality_check/t1_volume/utils.py b/clinicadl/quality_check/t1_volume/utils.py index 6d79d9c9f..6d22c82db 100644 --- a/clinicadl/quality_check/t1_volume/utils.py +++ b/clinicadl/quality_check/t1_volume/utils.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd -from clinicadl.utils.clinica_utils import RemoteFileStructure, fetch_file +from clinicadl.utils.iotools.clinica_utils import RemoteFileStructure, fetch_file def extract_metrics(caps_dir: Path, output_dir: Path, group_label): diff --git a/clinicadl/random_search/random_search_utils.py b/clinicadl/random_search/random_search_utils.py index 240232676..ed164ea0c 100644 --- a/clinicadl/random_search/random_search_utils.py +++ b/clinicadl/random_search/random_search_utils.py @@ -4,10 +4,10 @@ import toml -from clinicadl.caps_dataset.extraction.utils import path_decoder, read_preprocessing -from clinicadl.train.utils import extract_config_from_toml_file from clinicadl.utils.enum import Task from clinicadl.utils.exceptions import ClinicaDLConfigurationError +from clinicadl.utils.iotools.train_utils import extract_config_from_toml_file +from clinicadl.utils.iotools.utils import path_decoder, read_preprocessing def get_space_dict(launch_directory: Path) -> Dict[str, Any]: diff --git a/clinicadl/train/__init__.py b/clinicadl/train/__init__.py deleted file mode 100755 index e69de29bb..000000000 diff --git a/clinicadl/trainer/config/classification.py b/clinicadl/trainer/config/classification.py index 0fef7dadc..6472316f1 100644 --- a/clinicadl/trainer/config/classification.py +++ b/clinicadl/trainer/config/classification.py @@ -4,10 +4,10 @@ from pydantic import computed_field, field_validator from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig -from clinicadl.config.config.validation import ValidationConfig as BaseValidationConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task +from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.classification_config") diff --git a/clinicadl/trainer/config/reconstruction.py b/clinicadl/trainer/config/reconstruction.py index 4651840ca..4ad9d5927 100644 --- a/clinicadl/trainer/config/reconstruction.py +++ b/clinicadl/trainer/config/reconstruction.py @@ -3,7 +3,6 @@ from pydantic import PositiveFloat, PositiveInt, computed_field, field_validator -from clinicadl.config.config.validation import ValidationConfig as BaseValidationConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import ( @@ -12,6 +11,7 @@ ReconstructionMetric, Task, ) +from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.reconstruction_config") diff --git a/clinicadl/trainer/config/regression.py b/clinicadl/trainer/config/regression.py index 227cbe232..d68a873f8 100644 --- a/clinicadl/trainer/config/regression.py +++ b/clinicadl/trainer/config/regression.py @@ -4,10 +4,10 @@ from pydantic import computed_field, field_validator from clinicadl.caps_dataset.data_config import DataConfig as BaseDataConfig -from clinicadl.config.config.validation import ValidationConfig as BaseValidationConfig from clinicadl.network.config import NetworkConfig as BaseNetworkConfig from clinicadl.trainer.config.train import TrainConfig from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task +from clinicadl.validation.validation import ValidationConfig as BaseValidationConfig logger = getLogger("clinicadl.reconstruction_config") logger = getLogger("clinicadl.regression_config") diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index 22dac31f2..3a3d791a8 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -12,20 +12,20 @@ from clinicadl.callbacks.config import CallbacksConfig from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.computational import ComputationalConfig -from clinicadl.config.config.cross_validation import CrossValidationConfig -from clinicadl.config.config.early_stopping import EarlyStoppingConfig from clinicadl.config.config.lr_scheduler import LRschedulerConfig -from clinicadl.config.config.maps_manager import MapsManagerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.config.config.transfer_learning import TransferLearningConfig -from clinicadl.config.config.validation import ValidationConfig +from clinicadl.maps_manager.config import MapsManagerConfig from clinicadl.network.config import NetworkConfig from clinicadl.optimizer.optimization import OptimizationConfig from clinicadl.optimizer.optimizer import OptimizerConfig +from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils.computational.computational import ComputationalConfig +from clinicadl.utils.early_stopping.config import EarlyStoppingConfig from clinicadl.utils.enum import Task +from clinicadl.validation.cross_validation import CrossValidationConfig +from clinicadl.validation.validation import ValidationConfig logger = getLogger("clinicadl.training_config") @@ -111,7 +111,7 @@ def update_with_toml(self, path: Union[str, Path]) -> None: path : Union[str, Path] Path to the TOML configuration file. """ - from clinicadl.train.utils import extract_config_from_toml_file + from clinicadl.utils.iotools.train_utils import extract_config_from_toml_file path = Path(path) config_dict = extract_config_from_toml_file(path, self.network_task) diff --git a/clinicadl/train/tasks_utils.py b/clinicadl/trainer/tasks_utils.py similarity index 100% rename from clinicadl/train/tasks_utils.py rename to clinicadl/trainer/tasks_utils.py diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 22b8e01b4..96b9fbc4b 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -15,19 +15,22 @@ from torch.utils.data.distributed import DistributedSampler from clinicadl.caps_dataset.data import return_dataset -from clinicadl.utils.early_stopping import EarlyStopping +from clinicadl.utils.early_stopping.early_stopping import EarlyStopping from clinicadl.utils.exceptions import MAPSError -from clinicadl.utils.maps_manager.ddp import DDP, cluster -from clinicadl.utils.maps_manager.logwriter import LogWriter -from clinicadl.utils.maps_manager.maps_manager_utils import read_json +from clinicadl.utils.computational.ddp import DDP +from clinicadl.utils import cluster +from clinicadl.utils.logwriter import LogWriter +from clinicadl.caps_dataset.caps_dataset_utils import read_json from clinicadl.utils.metric_module import RetainBest from clinicadl.utils.seed import pl_worker_init_function, seed_everything -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.maps_manager import MapsManager +from clinicadl.maps_manager.maps_manager import MapsManager from clinicadl.utils.seed import get_seed from clinicadl.utils.enum import Task -from .trainer_utils import create_parameters_dict, patch_to_read_json -from clinicadl.train.tasks_utils import create_training_config +from clinicadl.utils.iotools.trainer_utils import ( + create_parameters_dict, + patch_to_read_json, +) +from clinicadl.trainer.tasks_utils import create_training_config if TYPE_CHECKING: from clinicadl.callbacks.callbacks import Callback diff --git a/clinicadl/config/config/transfer_learning.py b/clinicadl/trainer/transfer_learning.py similarity index 100% rename from clinicadl/config/config/transfer_learning.py rename to clinicadl/trainer/transfer_learning.py diff --git a/clinicadl/tsvtools/get_labels/get_labels.py b/clinicadl/tsvtools/get_labels/get_labels.py index 041b301e9..baa6f48a2 100644 --- a/clinicadl/tsvtools/get_labels/get_labels.py +++ b/clinicadl/tsvtools/get_labels/get_labels.py @@ -27,7 +27,7 @@ neighbour_session, ) from clinicadl.utils.exceptions import ClinicaDLArgumentError, ClinicaDLTSVError -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.iotools import commandline_to_json logger = getLogger("clinicadl.tsvtools") diff --git a/clinicadl/tsvtools/getlabels/getlabels.py b/clinicadl/tsvtools/getlabels/getlabels.py index a2f2d84d2..8176454d2 100644 --- a/clinicadl/tsvtools/getlabels/getlabels.py +++ b/clinicadl/tsvtools/getlabels/getlabels.py @@ -28,7 +28,7 @@ neighbour_session, ) from clinicadl.utils.exceptions import ClinicaDLArgumentError, ClinicaDLTSVError -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.iotools import commandline_to_json logger = getLogger("clinicadl") diff --git a/clinicadl/tsvtools/kfold/kfold.py b/clinicadl/tsvtools/kfold/kfold.py index 12ffddd4b..04b88872e 100644 --- a/clinicadl/tsvtools/kfold/kfold.py +++ b/clinicadl/tsvtools/kfold/kfold.py @@ -9,7 +9,7 @@ from clinicadl.tsvtools.tsvtools_utils import extract_baseline, retrieve_longitudinal from clinicadl.utils.exceptions import ClinicaDLTSVError -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.iotools import commandline_to_json sex_dict = {"M": 0, "F": 1} logger = getLogger("clinicadl.tsvtools.kfold") diff --git a/clinicadl/tsvtools/split/split.py b/clinicadl/tsvtools/split/split.py index cef621714..6235ee2e8 100644 --- a/clinicadl/tsvtools/split/split.py +++ b/clinicadl/tsvtools/split/split.py @@ -19,7 +19,7 @@ retrieve_longitudinal, ) from clinicadl.utils.exceptions import ClinicaDLArgumentError, ClinicaDLTSVError -from clinicadl.utils.maps_manager.iotools import commandline_to_json +from clinicadl.utils.iotools.iotools import commandline_to_json sex_dict = {"M": 0, "F": 1} logger = getLogger("clinicadl.tsvtools.split") diff --git a/clinicadl/utils/maps_manager/cluster/__init__.py b/clinicadl/utils/cluster/__init__.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/__init__.py rename to clinicadl/utils/cluster/__init__.py diff --git a/clinicadl/utils/maps_manager/cluster/api/__init__.py b/clinicadl/utils/cluster/api/__init__.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/__init__.py rename to clinicadl/utils/cluster/api/__init__.py diff --git a/clinicadl/utils/maps_manager/cluster/api/auto_master_addr_port.py b/clinicadl/utils/cluster/api/auto_master_addr_port.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/auto_master_addr_port.py rename to clinicadl/utils/cluster/api/auto_master_addr_port.py diff --git a/clinicadl/utils/maps_manager/cluster/api/base.py b/clinicadl/utils/cluster/api/base.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/base.py rename to clinicadl/utils/cluster/api/base.py diff --git a/clinicadl/utils/maps_manager/cluster/api/default.py b/clinicadl/utils/cluster/api/default.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/default.py rename to clinicadl/utils/cluster/api/default.py diff --git a/clinicadl/utils/maps_manager/cluster/api/slurm.py b/clinicadl/utils/cluster/api/slurm.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/slurm.py rename to clinicadl/utils/cluster/api/slurm.py diff --git a/clinicadl/utils/maps_manager/cluster/api/torchelastic.py b/clinicadl/utils/cluster/api/torchelastic.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/api/torchelastic.py rename to clinicadl/utils/cluster/api/torchelastic.py diff --git a/clinicadl/utils/maps_manager/cluster/config.py b/clinicadl/utils/cluster/config.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/config.py rename to clinicadl/utils/cluster/config.py diff --git a/clinicadl/utils/maps_manager/cluster/interface.py b/clinicadl/utils/cluster/interface.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/interface.py rename to clinicadl/utils/cluster/interface.py diff --git a/clinicadl/utils/maps_manager/cluster/utils.py b/clinicadl/utils/cluster/utils.py similarity index 100% rename from clinicadl/utils/maps_manager/cluster/utils.py rename to clinicadl/utils/cluster/utils.py diff --git a/clinicadl/config/config/computational.py b/clinicadl/utils/computational/computational.py similarity index 100% rename from clinicadl/config/config/computational.py rename to clinicadl/utils/computational/computational.py diff --git a/clinicadl/utils/maps_manager/ddp.py b/clinicadl/utils/computational/ddp.py similarity index 99% rename from clinicadl/utils/maps_manager/ddp.py rename to clinicadl/utils/computational/ddp.py index cbd6cc5c4..b4f30afbc 100644 --- a/clinicadl/utils/maps_manager/ddp.py +++ b/clinicadl/utils/computational/ddp.py @@ -31,7 +31,7 @@ else: fsdp_available = True -from . import cluster +from clinicadl.utils import cluster logger = logging.getLogger("DDP") diff --git a/clinicadl/config/config/early_stopping.py b/clinicadl/utils/early_stopping/config.py similarity index 100% rename from clinicadl/config/config/early_stopping.py rename to clinicadl/utils/early_stopping/config.py diff --git a/clinicadl/utils/early_stopping.py b/clinicadl/utils/early_stopping/early_stopping.py similarity index 100% rename from clinicadl/utils/early_stopping.py rename to clinicadl/utils/early_stopping/early_stopping.py diff --git a/clinicadl/utils/iotools/__init__.py b/clinicadl/utils/iotools/__init__.py new file mode 100644 index 000000000..f2c3432c4 --- /dev/null +++ b/clinicadl/utils/iotools/__init__.py @@ -0,0 +1,42 @@ +from .clinica_utils import ( + FileType, + check_bids_folder, + check_caps_folder, + clinicadl_file_reader, + container_from_filename, + create_subs_sess_list, + determine_caps_or_bids, + fetch_file, + find_sub_ses_pattern_path, + get_filename_no_ext, + get_subject_session_list, + insensitive_glob, + read_participant_tsv, +) +from .data_utils import ( + check_multi_cohort_tsv, + check_test_path, + load_data_test, + load_data_test_single, +) +from .iotools import ( + check_and_clean, + check_and_complete, + commandline_to_json, + cpuStats, + memReport, + write_requirements_version, +) +from .maps_manager_utils import add_default_values, remove_unused_tasks +from .read_utils import ( + get_info_from_filename, + get_mask_checksum_and_filename, + get_mask_path, +) +from .train_utils import ( + extract_config_from_toml_file, + get_model_list, + merge_cli_and_config_file_options, +) +from .trainer_utils import create_parameters_dict, patch_to_read_json +from .utils import path_decoder, path_encoder, read_preprocessing, write_preprocessing diff --git a/clinicadl/utils/clinica_utils.py b/clinicadl/utils/iotools/clinica_utils.py similarity index 86% rename from clinicadl/utils/clinica_utils.py rename to clinicadl/utils/iotools/clinica_utils.py index 463a305ab..702a2b599 100644 --- a/clinicadl/utils/clinica_utils.py +++ b/clinicadl/utils/iotools/clinica_utils.py @@ -15,18 +15,7 @@ import pandas as pd from pydantic import BaseModel -from clinicadl.caps_dataset.preprocessing import config as preprocessing_config -from clinicadl.utils.enum import ( - DTIMeasure, - DTISpace, - ImageModality, - LinearModality, - Preprocessing, - SUVRReferenceRegions, - Tracer, -) from clinicadl.utils.exceptions import ( - ClinicaDLArgumentError, ClinicaDLBIDSError, ClinicaDLCAPSError, ) @@ -41,143 +30,6 @@ class FileType(BaseModel): needed_pipeline: Optional[str] = None -def bids_nii( - config: preprocessing_config.PreprocessingConfig, - reconstruction: Optional[str] = None, -) -> FileType: - """Return the query dict required to capture PET scans. - - Parameters - ---------- - tracer : Tracer, optional - If specified, the query will only match PET scans acquired - with the requested tracer. - If None, the query will match all PET sans independently of - the tracer used. - - reconstruction : ReconstructionMethod, optional - If specified, the query will only match PET scans reconstructed - with the specified method. - If None, the query will match all PET scans independently of the - reconstruction method used. - - Returns - ------- - dict : - The query dictionary to get PET scans. - """ - - if config.preprocessing not in Preprocessing: - raise ClinicaDLArgumentError( - f"ClinicaDL is Unable to read this modality ({config.preprocessing}) of images, please chose one from this list: {list[Preprocessing]}" - ) - - if isinstance(config, preprocessing_config.PETPreprocessingConfig): - trc = "" if config.tracer is None else f"_trc-{Tracer(config.tracer).value}" - rec = "" if reconstruction is None else f"_rec-{reconstruction}" - description = "PET data" - - if config.tracer: - description += f" with {config.tracer.value} tracer" - if reconstruction: - description += f" and reconstruction method {reconstruction}" - - file_type = FileType( - pattern=f"pet/*{trc}{rec}_pet.nii*", description=description - ) - return file_type - - elif isinstance(config, preprocessing_config.T1PreprocessingConfig): - return FileType(pattern="anat/sub-*_ses-*_T1w.nii*", description="T1w MRI") - - elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): - return FileType(pattern="sub-*_ses-*_flair.nii*", description="FLAIR T2w MRI") - - elif isinstance(config, preprocessing_config.DTIPreprocessingConfig): - return FileType(pattern="dwi/sub-*_ses-*_dwi.nii*", description="DWI NIfTI") - - else: - raise ClinicaDLArgumentError("Invalid preprocessing") - - -def linear_nii( - config: preprocessing_config.PreprocessingConfig, -) -> FileType: - if isinstance(config, preprocessing_config.T1PreprocessingConfig): - needed_pipeline = Preprocessing.T1_LINEAR - modality = LinearModality.T1W - elif isinstance(config, preprocessing_config.T2PreprocessingConfig): - needed_pipeline = Preprocessing.T2_LINEAR - modality = LinearModality.T2W - elif isinstance(config, preprocessing_config.FlairPreprocessingConfig): - needed_pipeline = Preprocessing.FLAIR_LINEAR - modality = LinearModality.FLAIR - else: - raise ClinicaDLArgumentError("Invalid configuration") - - if config.use_uncropped_image: - desc_crop = "" - else: - desc_crop = "_desc-Crop" - - file_type = FileType( - pattern=f"*space-MNI152NLin2009cSym{desc_crop}_res-1x1x1_{modality.value}.nii.gz", - description=f"{modality.value} Image registered in MNI152NLin2009cSym space using {needed_pipeline.value} pipeline " - + ( - "" - if config.use_uncropped_image - else "and cropped (matrix size 169×208×179, 1 mm isotropic voxels)" - ), - needed_pipeline=needed_pipeline, - ) - return file_type - - -def dwi_dti(config: preprocessing_config.DTIPreprocessingConfig) -> FileType: - """Return the query dict required to capture DWI DTI images. - - Parameters - ---------- - config: DTIPreprocessingConfig - - Returns - ------- - FileType : - """ - if isinstance(config, preprocessing_config.DTIPreprocessingConfig): - measure = config.dti_measure - space = config.dti_space - else: - raise ClinicaDLArgumentError( - f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.DTIPreprocessingConfig}" - ) - - return FileType( - pattern=f"dwi/dti_based_processing/*/*_space-{space}_{measure.value}.nii.gz", - description=f"DTI-based {measure.value} in space {space}.", - needed_pipeline="dwi_dti", - ) - - -def pet_linear_nii(config: preprocessing_config.PETPreprocessingConfig) -> FileType: - if not isinstance(config, preprocessing_config.PETPreprocessingConfig): - raise ClinicaDLArgumentError( - f"PreprocessingConfig is of type {config} but should be of type{preprocessing_config.PETPreprocessingConfig}" - ) - - if config.use_uncropped_image: - description = "" - else: - description = "_desc-Crop" - - file_type = FileType( - pattern=f"pet_linear/*_trc-{config.tracer.value}_space-MNI152NLin2009cSym{description}_res-1x1x1_suvr-{config.suvr_reference_region.value}_pet.nii.gz", - description="", - needed_pipeline="pet-linear", - ) - return file_type - - def container_from_filename(bids_or_caps_filename: Path) -> Path: """Extract container from BIDS or CAPS file. diff --git a/clinicadl/caps_dataset/data_utils.py b/clinicadl/utils/iotools/data_utils.py similarity index 99% rename from clinicadl/caps_dataset/data_utils.py rename to clinicadl/utils/iotools/data_utils.py index 66ea5d8a0..96d3268c9 100644 --- a/clinicadl/caps_dataset/data_utils.py +++ b/clinicadl/utils/iotools/data_utils.py @@ -2,7 +2,6 @@ # TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? from logging import getLogger from pathlib import Path -from typing import Any, Callable, Dict, Optional import pandas as pd diff --git a/clinicadl/utils/maps_manager/iotools.py b/clinicadl/utils/iotools/iotools.py similarity index 100% rename from clinicadl/utils/maps_manager/iotools.py rename to clinicadl/utils/iotools/iotools.py diff --git a/clinicadl/utils/iotools/maps_manager_utils.py b/clinicadl/utils/iotools/maps_manager_utils.py new file mode 100644 index 000000000..5e2d6d320 --- /dev/null +++ b/clinicadl/utils/iotools/maps_manager_utils.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path +from typing import Any, Dict + +import toml + +from clinicadl.utils.iotools.utils import path_decoder + + +def add_default_values(user_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Updates the training parameters defined by the user with the default values in missing fields. + + Args: + user_dict: dictionary of training parameters defined by the user. + + Returns: + dictionary of values ready to use for the training process. + """ + task = user_dict["network_task"] + # read default values + clinicadl_root_dir = Path(__file__).parents[2] + config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" + # from clinicadl.utils.preprocessing import path_decoder + config_dict = toml.load(config_path) + # config_dict = path_decoder(config_dict) + + # task dependent + config_dict = remove_unused_tasks(config_dict, task) + + # Check that TOML file has the same format as the one in resources + for section_name in config_dict: + for key in config_dict[section_name]: + if key not in user_dict: # Add value if not present in user_dict + user_dict[key] = config_dict[section_name][key] + + # Hard-coded options + if user_dict["n_splits"] and user_dict["n_splits"] > 1: + user_dict["validation"] = "KFoldSplit" + else: + user_dict["validation"] = "SingleSplit" + + user_dict = path_decoder(user_dict) + + return user_dict + + +def remove_unused_tasks( + toml_dict: Dict[str, Dict[str, Any]], task: str +) -> Dict[str, Dict[str, Any]]: + """ + Remove options depending on other tasks than task + + Args: + toml_dict: dictionary of options as written in a TOML file. + task: task which will be learnt by the network. + + Returns: + updated TOML dictionary. + """ + task_list = ["classification", "regression", "reconstruction"] + + if task not in task_list: + raise ValueError( + f"Invalid value for network_task {task}. " + f"Please task choose in {task_list}." + ) + task_list.remove(task) + + # Delete all sections related to other tasks + for other_task in task_list: + if other_task.capitalize() in toml_dict: + del toml_dict[other_task.capitalize()] + + return toml_dict diff --git a/clinicadl/utils/read_utils.py b/clinicadl/utils/iotools/read_utils.py similarity index 98% rename from clinicadl/utils/read_utils.py rename to clinicadl/utils/iotools/read_utils.py index b57bda2b8..e73688521 100644 --- a/clinicadl/utils/read_utils.py +++ b/clinicadl/utils/iotools/read_utils.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import Optional, Tuple -from clinicadl.utils.clinica_utils import ( +from clinicadl.utils.enum import MaskChecksum, Pathology +from clinicadl.utils.exceptions import ClinicaDLArgumentError, DownloadError +from clinicadl.utils.iotools.clinica_utils import ( RemoteFileStructure, fetch_file, ) -from clinicadl.utils.enum import MaskChecksum, Pathology -from clinicadl.utils.exceptions import ClinicaDLArgumentError, DownloadError def get_info_from_filename(image_path: Path) -> Tuple[str, str, str, str]: diff --git a/clinicadl/train/utils.py b/clinicadl/utils/iotools/train_utils.py similarity index 97% rename from clinicadl/train/utils.py rename to clinicadl/utils/iotools/train_utils.py index 2e92ec2fa..e4347de3b 100644 --- a/clinicadl/train/utils.py +++ b/clinicadl/utils/iotools/train_utils.py @@ -6,10 +6,10 @@ import toml from click.core import ParameterSource -from clinicadl.caps_dataset.extraction.utils import path_decoder from clinicadl.utils.enum import Task from clinicadl.utils.exceptions import ClinicaDLConfigurationError -from clinicadl.utils.maps_manager.maps_manager_utils import remove_unused_tasks +from clinicadl.utils.iotools.maps_manager_utils import remove_unused_tasks +from clinicadl.utils.iotools.utils import path_decoder def extract_config_from_toml_file(config_file: Path, task: Task) -> Dict[str, Any]: @@ -49,7 +49,7 @@ def extract_config_from_toml_file(config_file: Path, task: Task) -> Dict[str, An del user_dict["Random_Search"] # get the template - clinicadl_root_dir = Path(__file__).parents[1] + clinicadl_root_dir = Path(__file__).parents[2] config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" config_dict = toml.load(config_path) # Check that TOML file has the same format as the one in clinicadl/resources/config/train_config.toml @@ -193,6 +193,7 @@ def merge_cli_and_config_file_options(task: Task, **kwargs) -> Dict[str, Any]: # TODO : remove? try: options["maps_dir"] = options["output_maps_directory"] + except KeyError: pass ### diff --git a/clinicadl/trainer/trainer_utils.py b/clinicadl/utils/iotools/trainer_utils.py similarity index 100% rename from clinicadl/trainer/trainer_utils.py rename to clinicadl/utils/iotools/trainer_utils.py diff --git a/clinicadl/caps_dataset/extraction/utils.py b/clinicadl/utils/iotools/utils.py similarity index 100% rename from clinicadl/caps_dataset/extraction/utils.py rename to clinicadl/utils/iotools/utils.py diff --git a/clinicadl/utils/maps_manager/logwriter.py b/clinicadl/utils/logwriter.py similarity index 100% rename from clinicadl/utils/maps_manager/logwriter.py rename to clinicadl/utils/logwriter.py diff --git a/clinicadl/utils/maps_manager/__init__.py b/clinicadl/utils/maps_manager/__init__.py deleted file mode 100644 index 92fbfae72..000000000 --- a/clinicadl/utils/maps_manager/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .iotools import ( - check_and_complete, - commandline_to_json, - write_requirements_version, -) -from .logwriter import LogWriter -from .maps_manager import MapsManager diff --git a/clinicadl/utils/maps_manager/cluster/profiler/__init__.py b/clinicadl/utils/maps_manager/cluster/profiler/__init__.py deleted file mode 100644 index 5aea07d28..000000000 --- a/clinicadl/utils/maps_manager/cluster/profiler/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from torch.profiler import * - -from .patch_kineto import tensorboard_trace_handler diff --git a/clinicadl/utils/maps_manager/cluster/profiler/patch_kineto.py b/clinicadl/utils/maps_manager/cluster/profiler/patch_kineto.py deleted file mode 100644 index 2cc00602c..000000000 --- a/clinicadl/utils/maps_manager/cluster/profiler/patch_kineto.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -import sys -from functools import wraps -from re import sub -from typing import Optional - -import torch.profiler -from packaging.version import Version - - -class persistent_locals(object): - """ - Allows access to local variables of a function. - Shamelessly stolen from - https://stackoverflow.com/questions/9186395/python-is-there-a-way-to-get-a-local-function-variable-from-within-a-decorator - """ - - def __init__(self, func): - self._locals = {} - self.func = func - - def __call__(self, *args, **kwargs): - def tracer(frame, event, arg): - if event == "return": - self._locals = frame.f_locals.copy() - - # tracer is activated on next call, return or exception - sys.setprofile(tracer) - try: - # trace the function call - res = self.func(*args, **kwargs) - finally: - # disable tracer and replace with old one - sys.setprofile(None) - return res - - def clear_locals(self): - self._locals = {} - - @property - def locals(self): - return self._locals - - -if Version(torch.__version__) >= Version("1.12.0"): - # This tensorboard_trace_handler wraps Pytorch's version. It restores a feature - # of Kineto profiler which was lost when upgrading Pytorch from 1.11 to 1.12. - # In the profiler, some category names were changed. But in the tensorboard - # visualization from Kineto, those category have not been renamed accordingly. - # This loses the dataloader step profiling. We can restore this feature by - # renaming the category in the output trace file. - - @wraps(torch.profiler.tensorboard_trace_handler) - def tensorboard_trace_handler( - dir_name: str, worker_name: Optional[str] = None, use_gzip: bool = False - ): - handler_fn = torch.profiler.tensorboard_trace_handler( - dir_name=dir_name, - worker_name=worker_name, - use_gzip=use_gzip, - ) - handler_fn = persistent_locals(handler_fn) - - @wraps(handler_fn) - def wrapper(prof): - handler_fn(prof) - file_name = handler_fn._locals["file_name"] - dir_name = handler_fn._locals["dir_name"] - handler_fn.clear_locals() - with open(os.path.join(dir_name, file_name), "r+") as file: - content = file.read() - file.seek(0) - file.truncate() - - # Restore profiler steps and dataloader - new_content = sub("user_annotation", "cpu_op", content) - - # Restore runtime category - new_content = sub("cuda_runtime", "runtime", new_content) - - file.write(new_content) - - return wrapper - -else: - tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler diff --git a/clinicadl/utils/maps_manager/maps_manager_utils.py b/clinicadl/utils/maps_manager/maps_manager_utils.py deleted file mode 100644 index acb34df7a..000000000 --- a/clinicadl/utils/maps_manager/maps_manager_utils.py +++ /dev/null @@ -1,190 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict - -import toml - -from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type -from clinicadl.caps_dataset.extraction.utils import path_decoder - - -def add_default_values(user_dict: Dict[str, Any]) -> Dict[str, Any]: - """ - Updates the training parameters defined by the user with the default values in missing fields. - - Args: - user_dict: dictionary of training parameters defined by the user. - - Returns: - dictionary of values ready to use for the training process. - """ - task = user_dict["network_task"] - # read default values - clinicadl_root_dir = Path(__file__).parents[2] - config_path = clinicadl_root_dir / "resources" / "config" / "train_config.toml" - # from clinicadl.utils.preprocessing import path_decoder - config_dict = toml.load(config_path) - # config_dict = path_decoder(config_dict) - - # task dependent - config_dict = remove_unused_tasks(config_dict, task) - - # Check that TOML file has the same format as the one in resources - for section_name in config_dict: - for key in config_dict[section_name]: - if key not in user_dict: # Add value if not present in user_dict - user_dict[key] = config_dict[section_name][key] - - # Hard-coded options - if user_dict["n_splits"] and user_dict["n_splits"] > 1: - user_dict["validation"] = "KFoldSplit" - else: - user_dict["validation"] = "SingleSplit" - - user_dict = path_decoder(user_dict) - - return user_dict - - -def read_json(json_path: Path) -> Dict[str, Any]: - """ - Ensures retro-compatibility between the different versions of ClinicaDL. - - Parameters - ---------- - json_path: Path - path to the JSON file summing the parameters of a MAPS. - - Returns - ------- - A dictionary of training parameters. - """ - from clinicadl.caps_dataset.extraction.utils import path_decoder - - with json_path.open(mode="r") as f: - parameters = json.load(f, object_hook=path_decoder) - # Types of retro-compatibility - # Change arg name: ex network --> model - # Change arg value: ex for preprocessing: mni --> t1-extensive - # New arg with default hard-coded value --> discarded_slice --> 20 - retro_change_name = { - "model": "architecture", - "multi": "multi_network", - "minmaxnormalization": "normalize", - "num_workers": "n_proc", - "mode": "extract_method", - } - - retro_add = { - "optimizer": "Adam", - "loss": None, - } - - for old_name, new_name in retro_change_name.items(): - if old_name in parameters: - parameters[new_name] = parameters[old_name] - del parameters[old_name] - - for name, value in retro_add.items(): - if name not in parameters: - parameters[name] = value - - if "extract_method" in parameters: - parameters["mode"] = parameters["extract_method"] - # Value changes - if "use_cpu" in parameters: - parameters["gpu"] = not parameters["use_cpu"] - del parameters["use_cpu"] - if "nondeterministic" in parameters: - parameters["deterministic"] = not parameters["nondeterministic"] - del parameters["nondeterministic"] - - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - - from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=parameters["mode"], - preprocessing_type=parameters["preprocessing"], - **parameters, - ) - if "file_type" not in parameters["preprocessing_dict"]: - _, file_type = compute_folder_and_file_type(config) - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() - - return parameters - - -def remove_unused_tasks( - toml_dict: Dict[str, Dict[str, Any]], task: str -) -> Dict[str, Dict[str, Any]]: - """ - Remove options depending on other tasks than task - - Args: - toml_dict: dictionary of options as written in a TOML file. - task: task which will be learnt by the network. - - Returns: - updated TOML dictionary. - """ - task_list = ["classification", "regression", "reconstruction"] - - if task not in task_list: - raise ValueError( - f"Invalid value for network_task {task}. " - f"Please task choose in {task_list}." - ) - task_list.remove(task) - - # Delete all sections related to other tasks - for other_task in task_list: - if other_task.capitalize() in toml_dict: - del toml_dict[other_task.capitalize()] - - return toml_dict diff --git a/clinicadl/utils/meta_maps/getter.py b/clinicadl/utils/meta_maps/getter.py index 38307b11d..96d16ae67 100644 --- a/clinicadl/utils/meta_maps/getter.py +++ b/clinicadl/utils/meta_maps/getter.py @@ -6,8 +6,8 @@ import pandas as pd +from clinicadl.maps_manager.maps_manager import MapsManager from clinicadl.utils.exceptions import MAPSError -from clinicadl.utils.maps_manager import MapsManager def meta_maps_analysis(launch_dir: Path, evaluation_metric="loss"): diff --git a/clinicadl/utils/task_manager/task_manager.py b/clinicadl/utils/task_manager/task_manager.py index a0eb931a3..0561d2d27 100644 --- a/clinicadl/utils/task_manager/task_manager.py +++ b/clinicadl/utils/task_manager/task_manager.py @@ -11,7 +11,7 @@ from clinicadl.caps_dataset.data import CapsDataset from clinicadl.network.network import Network -from clinicadl.utils.maps_manager.ddp import cluster +from clinicadl.utils import cluster from clinicadl.utils.metric_module import MetricModule diff --git a/clinicadl/config/config/cross_validation.py b/clinicadl/validation/cross_validation.py similarity index 94% rename from clinicadl/config/config/cross_validation.py rename to clinicadl/validation/cross_validation.py index 3441d72d1..e7c4ed7f5 100644 --- a/clinicadl/config/config/cross_validation.py +++ b/clinicadl/validation/cross_validation.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt -from clinicadl.utils.maps_manager.maps_manager import MapsManager +from clinicadl.maps_manager.maps_manager import MapsManager logger = getLogger("clinicadl.cross_validation_config") diff --git a/clinicadl/utils/split_manager/__init__.py b/clinicadl/validation/split_manager/__init__.py similarity index 100% rename from clinicadl/utils/split_manager/__init__.py rename to clinicadl/validation/split_manager/__init__.py diff --git a/clinicadl/utils/split_manager/kfold.py b/clinicadl/validation/split_manager/kfold.py similarity index 94% rename from clinicadl/utils/split_manager/kfold.py rename to clinicadl/validation/split_manager/kfold.py index 5ee560862..a3c26baaa 100644 --- a/clinicadl/utils/split_manager/kfold.py +++ b/clinicadl/validation/split_manager/kfold.py @@ -1,6 +1,6 @@ from pathlib import Path -from clinicadl.utils.split_manager.split_manager import SplitManager +from clinicadl.validation.split_manager.split_manager import SplitManager class KFoldSplit(SplitManager): diff --git a/clinicadl/utils/split_manager/single_split.py b/clinicadl/validation/split_manager/single_split.py similarity index 92% rename from clinicadl/utils/split_manager/single_split.py rename to clinicadl/validation/split_manager/single_split.py index 1b72f82c7..6ff282bb2 100644 --- a/clinicadl/utils/split_manager/single_split.py +++ b/clinicadl/validation/split_manager/single_split.py @@ -1,6 +1,6 @@ from pathlib import Path -from clinicadl.utils.split_manager.split_manager import SplitManager +from clinicadl.validation.split_manager.split_manager import SplitManager class SingleSplit(SplitManager): diff --git a/clinicadl/utils/split_manager/split_manager.py b/clinicadl/validation/split_manager/split_manager.py similarity index 99% rename from clinicadl/utils/split_manager/split_manager.py rename to clinicadl/validation/split_manager/split_manager.py index 7ddd06161..5696e1571 100644 --- a/clinicadl/utils/split_manager/split_manager.py +++ b/clinicadl/validation/split_manager/split_manager.py @@ -4,12 +4,12 @@ import pandas as pd -from clinicadl.utils.clinica_utils import check_caps_folder from clinicadl.utils.exceptions import ( ClinicaDLArgumentError, ClinicaDLConfigurationError, ClinicaDLTSVError, ) +from clinicadl.utils.iotools.clinica_utils import check_caps_folder logger = getLogger("clinicadl.split_manager") diff --git a/clinicadl/config/config/validation.py b/clinicadl/validation/validation.py similarity index 100% rename from clinicadl/config/config/validation.py rename to clinicadl/validation/validation.py diff --git a/tests/test_resume.py b/tests/test_resume.py index 44af2f6d5..9fde97a45 100644 --- a/tests/test_resume.py +++ b/tests/test_resume.py @@ -6,7 +6,7 @@ import pytest -from clinicadl.utils.maps_manager import MapsManager +from clinicadl.maps_manager.maps_manager import MapsManager from .testing_tools import modify_maps diff --git a/tests/unittests/train/test_utils.py b/tests/unittests/train/test_utils.py index 9cad4dbed..6b33787eb 100644 --- a/tests/unittests/train/test_utils.py +++ b/tests/unittests/train/test_utils.py @@ -186,14 +186,14 @@ ], ) def test_extract_config_from_toml_file(config_file, task, expected_output): - from clinicadl.train.utils import extract_config_from_toml_file + from clinicadl.utils.iotools.train_utils import extract_config_from_toml_file assert extract_config_from_toml_file(config_file, task) == expected_output def test_extract_config_from_toml_file_exceptions(): - from clinicadl.train.utils import extract_config_from_toml_file from clinicadl.utils.exceptions import ClinicaDLConfigurationError + from clinicadl.utils.iotools.train_utils import extract_config_from_toml_file with pytest.raises(ClinicaDLConfigurationError): extract_config_from_toml_file( @@ -206,7 +206,7 @@ def test_merge_cli_and_config_file_options(): import click from click.testing import CliRunner - from clinicadl.train.utils import merge_cli_and_config_file_options + from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @click.command() @click.option("--config_file") diff --git a/tests/unittests/train/trainer/test_training_config.py b/tests/unittests/train/trainer/test_training_config.py index 49bfccadc..158a6d6c2 100644 --- a/tests/unittests/train/trainer/test_training_config.py +++ b/tests/unittests/train/trainer/test_training_config.py @@ -5,11 +5,11 @@ from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig -from clinicadl.config.config.cross_validation import CrossValidationConfig from clinicadl.config.config.ssda import SSDAConfig -from clinicadl.config.config.transfer_learning import TransferLearningConfig from clinicadl.network.config import NetworkConfig +from clinicadl.trainer.transfer_learning import TransferLearningConfig from clinicadl.transforms.config import TransformsConfig +from clinicadl.validation.cross_validation import CrossValidationConfig # Tests for customed validators # diff --git a/tests/unittests/utils/caps_dataset/test_data.py b/tests/unittests/utils/caps_dataset/test_data.py index 57b074e74..ccfa30e8d 100644 --- a/tests/unittests/utils/caps_dataset/test_data.py +++ b/tests/unittests/utils/caps_dataset/test_data.py @@ -13,7 +13,7 @@ ], ) def test_check_multi_cohort_tsv(dataframe_columns, purpose): - from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv + from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv assert ( check_multi_cohort_tsv(pd.DataFrame(columns=dataframe_columns), purpose) is None @@ -38,8 +38,8 @@ def test_check_multi_cohort_tsv(dataframe_columns, purpose): def test_check_multi_cohort_tsv_errors( dataframe_columns, purpose, expected_mandatory_columns ): - from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv from clinicadl.utils.exceptions import ClinicaDLTSVError + from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv with pytest.raises( ClinicaDLTSVError, diff --git a/tests/unittests/utils/test_clinica_utils.py b/tests/unittests/utils/test_clinica_utils.py index 8f08ff24c..7b87ceacb 100644 --- a/tests/unittests/utils/test_clinica_utils.py +++ b/tests/unittests/utils/test_clinica_utils.py @@ -22,7 +22,8 @@ def test_pet_linear_nii( tracer, suvr_reference_region, uncropped_image, expected_pattern ): from clinicadl.caps_dataset.preprocessing.config import PETPreprocessingConfig - from clinicadl.utils.clinica_utils import FileType, pet_linear_nii + from clinicadl.caps_dataset.preprocessing.utils import pet_linear_nii + from clinicadl.utils.iotools.clinica_utils import FileType config = PETPreprocessingConfig( tracer=tracer,