diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py index d135fbcb4..60aeb88a7 100644 --- a/clinicadl/maps_manager/maps_manager.py +++ b/clinicadl/maps_manager/maps_manager.py @@ -76,7 +76,15 @@ def __init__( ) test_parameters = self.get_parameters() # test_parameters = path_decoder(test_parameters) + from clinicadl.utils.task_manager.task_manager import TaskConfig + self.parameters = add_default_values(test_parameters) + self.task_config = TaskConfig( + network_task=self.network_task, + mode=self.mode, + n_classes=self.output_size, + ) + self.split_name = ( self._check_split_wording() ) # Used only for retro-compatibility @@ -417,9 +425,14 @@ def _check_args(self, parameters): self.parameters["label"] = None from clinicadl.utils.enum import Task - from clinicadl.utils.task_manager.task_manager import get_default_network + from clinicadl.utils.task_manager.task_manager import ( + TaskConfig, + get_default_network, + ) self.network_task = Task(self.parameters["network_task"]) + self.task_config = TaskConfig(self.network_task, self.mode, df=train_df) + # self.task_manager = self._init_task_manager(df=train_df) if self.parameters["architecture"] == "default": self.parameters["architecture"] = get_default_network(self.network_task) diff --git a/clinicadl/utils/task_manager/task_manager.py b/clinicadl/utils/task_manager/task_manager.py index 5045b1231..365f734d2 100644 --- a/clinicadl/utils/task_manager/task_manager.py +++ b/clinicadl/utils/task_manager/task_manager.py @@ -5,6 +5,11 @@ import pandas as pd import torch import torch.distributed as dist +from pydantic import ( + BaseModel, + ConfigDict, + computed_field, +) from torch import Tensor, nn from torch.cuda.amp import autocast from torch.nn.functional import softmax @@ -32,6 +37,8 @@ # elif network_task == Task.RECONSTRUCTION: +# This function is not useful anymore since we introduced config class +# default network will automatically be initialized when running the task def get_default_network(network_task: Task) -> str: # return Network """Returns the default network to use when no architecture is specified.""" if network_task == Task.CLASSIFICATION: @@ -109,7 +116,7 @@ def handle_reconstruction_loss(criterion, compatible_losses): def output_size( network_task: Union[str, Task], - input_size: Sequence[int], + input_size: Optional[Sequence[int]], df: pd.DataFrame, label: str, ) -> Union[int, Sequence[int]]: @@ -719,8 +726,27 @@ def get_sampler(weights): return get_sampler(weights) -class TaskManager: - def __init__(self, mode: str, n_classes: int = None): +class TaskConfig(BaseModel): + mode: str + network_task: Task + n_classe: Optional[int] = None + df: Optional[pd.DataFrame] = None + label: Optional[str] = None + + def __init__( + self, + network_task: Union[str, Task], + mode: str, + n_classes: Optional[int] = None, + df: Optional[pd.DataFrame] = None, + label: Optional[str] = None, + ): + network_task = Task(network_task) + if network_task == Task.CLASSIFICATION: + if n_classes is None and df is not None: + n_classes = output_size(Task.CLASSIFICATION, None, df, label) + self.n_classes = n_classes + self.mode = mode self.metrics_module = MetricModule( evaluation_metrics(network_task), n_classes=n_classes