Some cleaning with utils files (#633)
* clean
Co-authored-by: thibaultdvx <[email protected]>
camillebrianceau authored Jun 26, 2024
1 parent 7e35b44 commit 64a8958
Showing 93 changed files with 537 additions and 617 deletions.
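At a glance, the commit moves helpers out of clinicadl.utils.clinica_utils: the NIfTI file-type helpers now come from clinicadl.caps_dataset.preprocessing.utils, while FileType (and helpers such as clinicadl_file_reader and check_caps_folder) move to clinicadl.utils.iotools.clinica_utils. A minimal before/after sketch of the import change, reconstructed from the hunks below:

# Before this commit
from clinicadl.utils.clinica_utils import (
    FileType,
    bids_nii,
    dwi_dti,
    linear_nii,
    pet_linear_nii,
)

# After this commit
from clinicadl.caps_dataset.preprocessing.utils import (
    bids_nii,
    dwi_dti,
    linear_nii,
    pet_linear_nii,
)
from clinicadl.utils.iotools.clinica_utils import FileType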
6 changes: 3 additions & 3 deletions clinicadl/caps_dataset/caps_dataset_config.py
@@ -14,15 +14,15 @@
PreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.clinica_utils import (
FileType,
from clinicadl.caps_dataset.preprocessing.utils import (
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.enum import ExtractionMethod, Preprocessing
from clinicadl.utils.iotools.clinica_utils import FileType


def get_extraction(extract_method: ExtractionMethod):
143 changes: 140 additions & 3 deletions clinicadl/caps_dataset/caps_dataset_utils.py
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.preprocessing.config import (
@@ -9,14 +10,15 @@
PETPreprocessingConfig,
T1PreprocessingConfig,
)
from clinicadl.utils.clinica_utils import (
FileType,
from clinicadl.caps_dataset.preprocessing.utils import (
bids_nii,
dwi_dti,
linear_nii,
pet_linear_nii,
)
from clinicadl.utils.enum import Preprocessing
from clinicadl.utils.exceptions import ClinicaDLArgumentError
from clinicadl.utils.iotools.clinica_utils import FileType


def compute_folder_and_file_type(
@@ -54,3 +56,138 @@ def compute_folder_and_file_type(
description="Custom suffix",
)
return mod_subfolder, file_type


def find_file_type(config: CapsDatasetConfig) -> FileType:
if isinstance(config.preprocessing, T1PreprocessingConfig):
file_type = linear_nii(config.preprocessing)
elif isinstance(config.preprocessing, PETPreprocessingConfig):
if (
config.preprocessing.tracer is None
or config.preprocessing.suvr_reference_region is None
):
raise ClinicaDLArgumentError(
"`tracer` and `suvr_reference_region` must be defined "
"when using `pet-linear` preprocessing."
)
file_type = pet_linear_nii(config.preprocessing)
else:
raise NotImplementedError(
f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}"
)

return file_type


def read_json(json_path: Path) -> Dict[str, Any]:
"""
Ensures retro-compatibility between the different versions of ClinicaDL.
Parameters
----------
json_path: Path
path to the JSON file summing the parameters of a MAPS.
Returns
-------
A dictionary of training parameters.
"""
from clinicadl.utils.iotools.utils import path_decoder

with json_path.open(mode="r") as f:
parameters = json.load(f, object_hook=path_decoder)
# Types of retro-compatibility
# Change arg name: ex network --> model
# Change arg value: ex for preprocessing: mni --> t1-extensive
# New arg with default hard-coded value --> discarded_slice --> 20
retro_change_name = {
"model": "architecture",
"multi": "multi_network",
"minmaxnormalization": "normalize",
"num_workers": "n_proc",
"mode": "extract_method",
}

retro_add = {
"optimizer": "Adam",
"loss": None,
}

for old_name, new_name in retro_change_name.items():
if old_name in parameters:
parameters[new_name] = parameters[old_name]
del parameters[old_name]

for name, value in retro_add.items():
if name not in parameters:
parameters[name] = value

if "extract_method" in parameters:
parameters["mode"] = parameters["extract_method"]
# Value changes
if "use_cpu" in parameters:
parameters["gpu"] = not parameters["use_cpu"]
del parameters["use_cpu"]
if "nondeterministic" in parameters:
parameters["deterministic"] = not parameters["nondeterministic"]
del parameters["nondeterministic"]

# Build preprocessing_dict
if "preprocessing_dict" not in parameters:
parameters["preprocessing_dict"] = {"mode": parameters["mode"]}
preprocessing_options = [
"preprocessing",
"use_uncropped_image",
"prepare_dl",
"custom_suffix",
"tracer",
"suvr_reference_region",
"patch_size",
"stride_size",
"slice_direction",
"slice_mode",
"discarded_slices",
"roi_list",
"uncropped_roi",
"roi_custom_suffix",
"roi_custom_template",
"roi_custom_mask_pattern",
]
for preprocessing_var in preprocessing_options:
if preprocessing_var in parameters:
parameters["preprocessing_dict"][preprocessing_var] = parameters[
preprocessing_var
]
del parameters[preprocessing_var]

# Add missing parameters in previous version of extract
if "use_uncropped_image" not in parameters["preprocessing_dict"]:
parameters["preprocessing_dict"]["use_uncropped_image"] = False

if (
"prepare_dl" not in parameters["preprocessing_dict"]
and parameters["mode"] != "image"
):
parameters["preprocessing_dict"]["prepare_dl"] = False

if (
parameters["mode"] == "slice"
and "slice_mode" not in parameters["preprocessing_dict"]
):
parameters["preprocessing_dict"]["slice_mode"] = "rgb"

if "preprocessing" not in parameters:
parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"]

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig

config = CapsDatasetConfig.from_preprocessing_and_extraction_method(
extraction=parameters["mode"],
preprocessing_type=parameters["preprocessing"],
**parameters,
)
if "file_type" not in parameters["preprocessing_dict"]:
_, file_type = compute_folder_and_file_type(config)
parameters["preprocessing_dict"]["file_type"] = file_type.model_dump()

return parameters
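
The commit also adds the find_file_type and read_json helpers shown above to caps_dataset_utils.py. Below is a minimal usage sketch of read_json, not part of the commit itself; the MAPS path is hypothetical and the import assumes the module edited above:

from pathlib import Path

from clinicadl.caps_dataset.caps_dataset_utils import read_json

# Hypothetical path to a MAPS parameter file.
parameters = read_json(Path("maps_example/maps.json"))
# Legacy keys (e.g. "model", "use_cpu") are remapped to their current names
# ("architecture", "gpu"), and "preprocessing_dict" is rebuilt when missing.
print(parameters.get("architecture"), parameters.get("preprocessing_dict"))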
9 changes: 3 additions & 6 deletions clinicadl/caps_dataset/data.py
@@ -3,15 +3,14 @@
import abc
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
from clinicadl.caps_dataset.caps_dataset_utils import compute_folder_and_file_type
from clinicadl.caps_dataset.extraction.config import (
ExtractionImageConfig,
ExtractionPatchConfig,
@@ -30,15 +29,13 @@
)
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils.enum import (
ExtractionMethod,
Pattern,
Preprocessing,
SliceDirection,
SliceMode,
Template,
)
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLCAPSError,
ClinicaDLTSVError,
)
@@ -133,7 +130,7 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:
Returns:
image_path: path to the tensor containing the whole image.
"""
from clinicadl.utils.clinica_utils import clinicadl_file_reader
from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader

# Try to find .nii.gz file
try:
@@ -221,7 +218,7 @@ def _get_full_image(self) -> torch.Tensor:
"""
import nibabel as nib

from clinicadl.utils.clinica_utils import clinicadl_file_reader
from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader

participant_id = self.df.loc[0, "participant_id"]
session_id = self.df.loc[0, "session_id"]
6 changes: 3 additions & 3 deletions clinicadl/caps_dataset/data_config.py
@@ -5,13 +5,13 @@
import pandas as pd
from pydantic import BaseModel, ConfigDict, computed_field, field_validator

from clinicadl.caps_dataset.data_utils import check_multi_cohort_tsv, load_data_test
from clinicadl.caps_dataset.extraction.utils import read_preprocessing
from clinicadl.utils.enum import Mode
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLTSVError,
)
from clinicadl.utils.iotools.data_utils import check_multi_cohort_tsv, load_data_test
from clinicadl.utils.iotools.utils import read_preprocessing

logger = getLogger("clinicadl.data_config")

@@ -85,7 +85,7 @@ def check_data_tsv(cls, v) -> Path:
@computed_field
@property
def caps_dict(self) -> Dict[str, Path]:
from clinicadl.utils.clinica_utils import check_caps_folder
from clinicadl.utils.iotools.clinica_utils import check_caps_folder

if self.multi_cohort:
if self.caps_directory.suffix != ".tsv":
2 changes: 1 addition & 1 deletion clinicadl/caps_dataset/extraction/config.py
@@ -5,12 +5,12 @@
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import NonNegativeInt

from clinicadl.utils.clinica_utils import FileType
from clinicadl.utils.enum import (
ExtractionMethod,
SliceDirection,
SliceMode,
)
from clinicadl.utils.iotools.clinica_utils import FileType

logger = getLogger("clinicadl.preprocessing_config")

[Diffs for the remaining 88 changed files are not shown here.]
