Skip to content

Commit

Permalink
Predict and interpret adaptation to data class (#586)
Browse files Browse the repository at this point in the history
* adapt predict and interpret to config classes
  • Loading branch information
camillebrianceau authored May 30, 2024
1 parent c768094 commit c278622
Show file tree
Hide file tree
Showing 63 changed files with 375 additions and 1,128 deletions.
2 changes: 0 additions & 2 deletions clinicadl/config/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .data import DataConfig
from .dataloader import DataLoaderConfig
from .early_stopping import EarlyStoppingConfig
from .interpret import InterpretConfig
from .lr_scheduler import LRschedulerConfig
from .maps_manager import MapsManagerConfig
from .modality import (
Expand All @@ -17,7 +16,6 @@
from .model import ModelConfig
from .optimization import OptimizationConfig
from .optimizer import OptimizerConfig
from .predict import PredictConfig
from .preprocessing import (
PreprocessingConfig,
PreprocessingImageConfig,
Expand Down
16 changes: 15 additions & 1 deletion clinicadl/config/config/computational.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from logging import getLogger

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

from clinicadl.utils.cmdline_utils import check_gpu
from clinicadl.utils.exceptions import ClinicaDLArgumentError

logger = getLogger("clinicadl.computational_config")

Expand All @@ -13,3 +17,13 @@ class ComputationalConfig(BaseModel):
gpu: bool = True
# pydantic config
model_config = ConfigDict(validate_assignment=True)

@model_validator(mode="after")
def validator_gpu(self) -> Self:
if self.gpu:
check_gpu()
elif self.amp:
raise ClinicaDLArgumentError(
"AMP is designed to work with modern GPUs. Please add the --gpu flag."
)
return self
10 changes: 9 additions & 1 deletion clinicadl/config/config/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.types import NonNegativeInt

from clinicadl.utils.maps_manager.maps_manager import MapsManager

logger = getLogger("clinicadl.cross_validation_config")


Expand All @@ -19,7 +21,7 @@ class CrossValidationConfig(

n_splits: NonNegativeInt = 0
split: Optional[Tuple[NonNegativeInt, ...]] = None
tsv_directory: Path
tsv_directory: Optional[Path] = None # not needed in predict ?
# pydantic config
model_config = ConfigDict(validate_assignment=True)

Expand All @@ -28,3 +30,9 @@ def validator_split(cls, v):
if isinstance(v, list):
return tuple(v)
return v # TODO : check that split exists (and check coherence with n_splits)

def adapt_cross_val_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if not self.split:
self.split = maps_manager._find_splits()
logger.debug(f"List of splits {self.split}")
8 changes: 7 additions & 1 deletion clinicadl/config/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from clinicadl.utils.caps_dataset.data import load_data_test
from clinicadl.utils.enum import Mode
from clinicadl.utils.maps_manager.maps_manager import MapsManager
from clinicadl.utils.preprocessing import read_preprocessing

logger = getLogger("clinicadl.data_config")
Expand All @@ -24,12 +25,17 @@ class DataConfig(BaseModel): # TODO : put in data module
label: Optional[str] = None
label_code: Dict[str, int] = {}
multi_cohort: bool = False
preprocessing_json: Path
preprocessing_json: Optional[Path] = None
data_tsv: Optional[Path] = None
n_subjects: int = 300
# pydantic config
model_config = ConfigDict(validate_assignment=True)

def adapt_data_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if self.diagnoses is None or len(self.diagnoses) == 0:
self.diagnoses = maps_manager.diagnoses

def create_groupe_df(self):
group_df = None
if self.data_tsv is not None and self.data_tsv.is_file():
Expand Down
9 changes: 9 additions & 0 deletions clinicadl/config/config/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic.types import PositiveInt

from clinicadl.utils.enum import Sampler
from clinicadl.utils.maps_manager.maps_manager import MapsManager

logger = getLogger("clinicadl.dataloader_config")

Expand All @@ -16,3 +17,11 @@ class DataLoaderConfig(BaseModel): # TODO : put in data/splitter module
sampler: Sampler = Sampler.RANDOM
# pydantic config
model_config = ConfigDict(validate_assignment=True)

def adapt_dataloader_with_maps_manager_info(self, maps_manager: MapsManager):
# TEMPORARY
if not self.batch_size:
self.batch_size = maps_manager.batch_size

if not self.n_proc:
self.n_proc = maps_manager.n_proc
14 changes: 0 additions & 14 deletions clinicadl/config/config/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,3 @@ def check_output_saving_nifti(self, network_task: str) -> None:
raise ClinicaDLArgumentError(
"Cannot save nifti if the network task is not reconstruction. Please remove --save_nifti option."
)

def adapt_config_with_maps_manager_info(self, maps_manager: MapsManager):
if not self.split_list:
self.split_list = maps_manager._find_splits()
logger.debug(f"List of splits {self.split_list}")

if self.diagnoses is None or len(self.diagnoses) == 0:
self.diagnoses = maps_manager.diagnoses

if not self.batch_size:
self.batch_size = maps_manager.batch_size

if not self.n_proc:
self.n_proc = maps_manager.n_proc
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,28 @@

from pydantic import BaseModel, field_validator

from clinicadl.config.config import (
ComputationalConfig,
CrossValidationConfig,
DataLoaderConfig,
MapsManagerConfig,
ValidationConfig,
)
from clinicadl.config.config import DataConfig as DataBaseConfig
from clinicadl.interpret.gradients import GradCam, Gradients, VanillaBackProp
from clinicadl.utils.caps_dataset.data import (
load_data_test,
)
from clinicadl.utils.enum import InterpretationMethod
from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore
from clinicadl.utils.maps_manager.maps_manager import MapsManager # type: ignore

logger = getLogger("clinicadl.predict_config")
logger = getLogger("clinicadl.interpret_config")


class DataConfig(DataBaseConfig):
caps_directory: Optional[Path] = None


class InterpretConfig(BaseModel):
class InterpretBaseConfig(BaseModel):
name: str
method: InterpretationMethod = InterpretationMethod.GRADIENTS
target_node: int = 0
Expand All @@ -38,3 +48,15 @@ def get_method(self) -> Gradients:
return GradCam
else:
raise ValueError(f"The method {self.method.value} is not implemented")


class InterpretConfig(
MapsManagerConfig,
InterpretBaseConfig,
DataConfig,
ValidationConfig,
CrossValidationConfig,
ComputationalConfig,
DataLoaderConfig,
):
"""Config class to perform Transfer Learning."""
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,42 @@

from pydantic import BaseModel

from clinicadl.config.config.data import DataConfig as DataBaseConfig
from clinicadl.config.config.maps_manager import (
MapsManagerConfig as MapsManagerBaseConfig,
)
from clinicadl.utils.exceptions import ClinicaDLArgumentError # type: ignore

from ..computational import ComputationalConfig
from ..cross_validation import CrossValidationConfig
from ..dataloader import DataLoaderConfig
from ..validation import ValidationConfig

logger = getLogger("clinicadl.predict_config")


class PredictConfig(BaseModel):
class MapsManagerConfig(MapsManagerBaseConfig):
save_tensor: bool = False
save_latent_tensor: bool = False
use_labels: bool = True

def check_output_saving_tensor(self, network_task: str) -> None:
# Check if task is reconstruction for "save_tensor" and "save_nifti"
if self.save_tensor and network_task != "reconstruction":
raise ClinicaDLArgumentError(
"Cannot save tensors if the network task is not reconstruction. Please remove --save_tensor option."
)


class DataConfig(DataBaseConfig):
use_labels: bool = True


class PredictConfig(
MapsManagerConfig,
DataConfig,
ValidationConfig,
CrossValidationConfig,
ComputationalConfig,
DataLoaderConfig,
):
"""Config class to perform Transfer Learning."""
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from clinicadl.config.config import DataConfig as BaseDataConfig
from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import ClassificationLoss, ClassificationMetric, Task

logger = getLogger("clinicadl.classification_config")
Expand Down Expand Up @@ -57,7 +57,7 @@ def list_to_tuples(cls, v):
return v


class ClassificationConfig(TrainingConfig):
class ClassificationConfig(TrainConfig):
"""
Config class for the training of a classification model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import (
Normalization,
ReconstructionLoss,
Expand Down Expand Up @@ -47,7 +47,7 @@ def list_to_tuples(cls, v):
return v


class ReconstructionConfig(TrainingConfig):
class ReconstructionConfig(TrainConfig):
"""
Config class for the training of a reconstruction model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from clinicadl.config.config import DataConfig as BaseDataConfig
from clinicadl.config.config import ModelConfig as BaseModelConfig
from clinicadl.config.config import ValidationConfig as BaseValidationConfig
from clinicadl.train.trainer.training_config import TrainingConfig
from clinicadl.config.config.pipelines.train import TrainConfig
from clinicadl.utils.enum import RegressionLoss, RegressionMetric, Task

logger = getLogger("clinicadl.reconstruction_config")
Expand Down Expand Up @@ -47,7 +47,7 @@ def list_to_tuples(cls, v):
return v


class RegressionConfig(TrainingConfig):
class RegressionConfig(TrainConfig):
"""
Config class for the training of a regression model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logger = getLogger("clinicadl.training_config")


class TrainingConfig(BaseModel, ABC):
class TrainConfig(BaseModel, ABC):
"""
Abstract config class for the training pipeline.
Expand Down
2 changes: 1 addition & 1 deletion clinicadl/config/options/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .task import classification, reconstruction, regression
# from .task import classification, reconstruction, regression
9 changes: 4 additions & 5 deletions clinicadl/config/options/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import click

import clinicadl.train.trainer.training_config as config
from clinicadl.config import config
from clinicadl.config.config.callbacks import CallbacksConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

emissions_calculator = click.option(
"--calculate_emissions/--dont_calculate_emissions",
default=get_default("emissions_calculator", config.CallbacksConfig),
default=get_default("emissions_calculator", CallbacksConfig),
help="Flag to allow calculate the carbon emissions during training.",
show_default=True,
)
track_exp = click.option(
"--track_exp",
"-te",
type=get_type("track_exp", config.CallbacksConfig),
default=get_default("track_exp", config.CallbacksConfig),
type=get_type("track_exp", CallbacksConfig),
default=get_default("track_exp", CallbacksConfig),
help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.",
show_default=True,
)
6 changes: 3 additions & 3 deletions clinicadl/config/options/computational.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import click

from clinicadl.config import config
from clinicadl.config.config.computational import ComputationalConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

# Computational
amp = click.option(
"--amp/--no-amp",
default=get_default("amp", config.ComputationalConfig),
default=get_default("amp", ComputationalConfig),
help="Enables automatic mixed precision during training and inference.",
show_default=True,
)
Expand All @@ -21,7 +21,7 @@
)
gpu = click.option(
"--gpu/--no-gpu",
default=get_default("gpu", config.ComputationalConfig),
default=get_default("gpu", ComputationalConfig),
help="Use GPU by default. Please specify `--no-gpu` to force using CPU.",
show_default=True,
)
9 changes: 4 additions & 5 deletions clinicadl/config/options/cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import click

import clinicadl.train.trainer.training_config as config
from clinicadl.config import config
from clinicadl.config.config.cross_validation import CrossValidationConfig
from clinicadl.utils.config_utils import get_default_from_config_class as get_default
from clinicadl.utils.config_utils import get_type_from_config_class as get_type

# Cross Validation
n_splits = click.option(
"--n_splits",
type=get_type("n_splits", config.CrossValidationConfig),
default=get_default("n_splits", config.CrossValidationConfig),
type=get_type("n_splits", CrossValidationConfig),
default=get_default("n_splits", CrossValidationConfig),
help="If a value is given for k will load data of a k-fold CV. "
"Default value (0) will load a single split.",
show_default=True,
Expand All @@ -18,7 +17,7 @@
"--split",
"-s",
type=int, # get_type("split", config.CrossValidationConfig),
default=get_default("split", config.CrossValidationConfig),
default=get_default("split", CrossValidationConfig),
multiple=True,
help="Train the list of given splits. By default, all the splits are trained.",
show_default=True,
Expand Down
Loading

0 comments on commit c278622

Please sign in to comment.