diff --git a/clinicadl/caps_dataset/extraction/config.py b/clinicadl/caps_dataset/extraction/config.py index f33b5b9fe..bb4bed89a 100644 --- a/clinicadl/caps_dataset/extraction/config.py +++ b/clinicadl/caps_dataset/extraction/config.py @@ -1,4 +1,5 @@ from logging import getLogger +from pathlib import Path from time import time from typing import List, Optional, Tuple, Union @@ -31,6 +32,10 @@ class ExtractionConfig(BaseModel): @field_validator("extract_json", mode="before") def compute_extract_json(cls, v: str): + if isinstance(v, Path): + v = str(v) + elif isinstance(v, bool): + v = None if v is None: return f"extract_{int(time())}.json" elif not v.endswith(".json"): @@ -75,3 +80,11 @@ class ExtractionROIConfig(ExtractionConfig): roi_custom_mask_pattern: str = "" roi_background_value: int = 0 extract_method: ExtractionMethod = ExtractionMethod.ROI + + +ALL_EXTRACTION_TYPES = Union[ + ExtractionImageConfig, + ExtractionROIConfig, + ExtractionSliceConfig, + ExtractionPatchConfig, +] diff --git a/clinicadl/caps_dataset/preprocessing/config.py b/clinicadl/caps_dataset/preprocessing/config.py index 447d1986f..fae4fa777 100644 --- a/clinicadl/caps_dataset/preprocessing/config.py +++ b/clinicadl/caps_dataset/preprocessing/config.py @@ -1,7 +1,7 @@ import abc from logging import getLogger from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from pydantic import BaseModel, ConfigDict @@ -207,3 +207,13 @@ def caps_nii(self) -> tuple: def get_filetype(self) -> FileType: return self.linear_nii() + + +ALL_PREPROCESSING_TYPES = Union[ + T1PreprocessingConfig, + T2PreprocessingConfig, + FlairPreprocessingConfig, + PETPreprocessingConfig, + CustomPreprocessingConfig, + DTIPreprocessingConfig, +] diff --git a/clinicadl/caps_dataset/utils.py b/clinicadl/caps_dataset/utils.py index 207af033f..7f7503997 100644 --- a/clinicadl/caps_dataset/utils.py +++ b/clinicadl/caps_dataset/utils.py @@ -61,8 +61,9 @@ def get_preprocessing_and_mode_from_parameters(**kwargs): if "preprocessing_dict" in kwargs: kwargs = kwargs["preprocessing_dict"] + print(kwargs) preprocessing = Preprocessing(kwargs["preprocessing"]) - mode = ExtractionMethod(kwargs["mode"]) + mode = ExtractionMethod(kwargs["extract_method"]) return get_preprocessing(preprocessing)(**kwargs), get_extraction(mode)(**kwargs) diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index 1da78b29a..4237083b3 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -154,7 +154,6 @@ def _check_args(self, parameters): mandatory_arguments = [ "caps_directory", "tsv_path", - "mode", "network_task", ] for arg in mandatory_arguments: diff --git a/clinicadl/trainer/config/train.py b/clinicadl/trainer/config/train.py index 30a92c92a..5e0f0210d 100644 --- a/clinicadl/trainer/config/train.py +++ b/clinicadl/trainer/config/train.py @@ -12,6 +12,8 @@ from clinicadl.callbacks.config import CallbacksConfig from clinicadl.caps_dataset.data_config import DataConfig from clinicadl.caps_dataset.dataloader_config import DataLoaderConfig +from clinicadl.caps_dataset.extraction.config import ALL_EXTRACTION_TYPES +from clinicadl.caps_dataset.preprocessing.config import ALL_PREPROCESSING_TYPES from clinicadl.config.config.lr_scheduler import LRschedulerConfig from clinicadl.config.config.reproducibility import ReproducibilityConfig from clinicadl.maps_manager.config import MapsManagerConfig @@ -42,9 +44,11 @@ class TrainConfig(BaseModel, ABC): data: DataConfig dataloader: DataLoaderConfig early_stopping: EarlyStoppingConfig + extraction: ALL_EXTRACTION_TYPES lr_scheduler: LRschedulerConfig maps_manager: MapsManagerConfig model: NetworkConfig + preprocessing: ALL_PREPROCESSING_TYPES optimization: OptimizationConfig optimizer: OptimizerConfig reproducibility: ReproducibilityConfig @@ -55,6 +59,9 @@ class TrainConfig(BaseModel, ABC): # pydantic config model_config = ConfigDict(validate_assignment=True) + # @field_validator("preprocessing", mode="before") + # def check_preprocessing(cls, v: str): + @computed_field @property @abstractmethod @@ -68,9 +75,11 @@ def __init__(self, **kwargs): data=kwargs, dataloader=kwargs, early_stopping=kwargs, + extraction=kwargs, lr_scheduler=kwargs, maps_manager=kwargs, model=kwargs, + preprocessing=kwargs, optimization=kwargs, optimizer=kwargs, reproducibility=kwargs, @@ -87,9 +96,11 @@ def _update(self, config_dict: Dict[str, Any]) -> None: self.data.__dict__.update(config_dict) self.dataloader.__dict__.update(config_dict) self.early_stopping.__dict__.update(config_dict) + self.extraction.__dict__.update(config_dict) self.lr_scheduler.__dict__.update(config_dict) self.maps_manager.__dict__.update(config_dict) self.model.__dict__.update(config_dict) + self.preprocessing.__dict__.update(config_dict) self.optimization.__dict__.update(config_dict) self.optimizer.__dict__.update(config_dict) self.reproducibility.__dict__.update(config_dict) diff --git a/clinicadl/trainer/trainer.py b/clinicadl/trainer/trainer.py index 7551f0241..a8f9fe9c7 100644 --- a/clinicadl/trainer/trainer.py +++ b/clinicadl/trainer/trainer.py @@ -125,8 +125,11 @@ def from_json( config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch config_dict["maps_dir"] = maps_path config_dict["split"] = split if split else () + + from clinicadl.utils.iotools.trainer_utils import read_multi_level_dict + config_object = create_training_config(config_dict["network_task"])( - **config_dict + **read_multi_level_dict(config_dict) ) return cls(config_object) diff --git a/clinicadl/utils/iotools/trainer_utils.py b/clinicadl/utils/iotools/trainer_utils.py index ac1b6a3bf..05ce45352 100644 --- a/clinicadl/utils/iotools/trainer_utils.py +++ b/clinicadl/utils/iotools/trainer_utils.py @@ -1,6 +1,16 @@ from pathlib import Path +def read_multi_level_dict(dict_): + parameters = {} + for key in dict_: + if isinstance(dict_[key], dict): + parameters.update(dict_[key]) + else: + parameters[key] = dict_[key] + return parameters + + def create_parameters_dict(config): parameters = {} config_dict = config.model_dump() @@ -20,7 +30,6 @@ def create_parameters_dict(config): if parameters["data_augmentation"] == (): parameters["data_augmentation"] = False - del parameters["preprocessing_json"] # if "tsv_path" in parameters: # parameters["tsv_path"] = parameters["tsv_path"] # del parameters["tsv_path"]