Skip to content

Commit

Permalink
try to create a taskconfig but will be removed after
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Aug 30, 2024
1 parent ad8f37c commit e027c21
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
15 changes: 14 additions & 1 deletion clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,15 @@ def __init__(
)
test_parameters = self.get_parameters()
# test_parameters = path_decoder(test_parameters)
from clinicadl.utils.task_manager.task_manager import TaskConfig

self.parameters = add_default_values(test_parameters)
self.task_config = TaskConfig(
network_task=self.network_task,
mode=self.mode,
n_classes=self.output_size,
)

self.split_name = (
self._check_split_wording()
) # Used only for retro-compatibility
Expand Down Expand Up @@ -417,9 +425,14 @@ def _check_args(self, parameters):
self.parameters["label"] = None

from clinicadl.utils.enum import Task
from clinicadl.utils.task_manager.task_manager import get_default_network
from clinicadl.utils.task_manager.task_manager import (
TaskConfig,
get_default_network,
)

self.network_task = Task(self.parameters["network_task"])
self.task_config = TaskConfig(self.network_task, self.mode, df=train_df)
# self.task_manager = self._init_task_manager(df=train_df)

if self.parameters["architecture"] == "default":
self.parameters["architecture"] = get_default_network(self.network_task)
Expand Down
32 changes: 29 additions & 3 deletions clinicadl/utils/task_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import pandas as pd
import torch
import torch.distributed as dist
from pydantic import (
BaseModel,
ConfigDict,
computed_field,
)
from torch import Tensor, nn
from torch.cuda.amp import autocast
from torch.nn.functional import softmax
Expand Down Expand Up @@ -32,6 +37,8 @@
# elif network_task == Task.RECONSTRUCTION:


# This function is not useful anymore since we introduced config class
# default network will automatically be initialized when running the task
def get_default_network(network_task: Task) -> str: # return Network
"""Returns the default network to use when no architecture is specified."""
if network_task == Task.CLASSIFICATION:
Expand Down Expand Up @@ -109,7 +116,7 @@ def handle_reconstruction_loss(criterion, compatible_losses):

def output_size(
network_task: Union[str, Task],
input_size: Sequence[int],
input_size: Optional[Sequence[int]],
df: pd.DataFrame,
label: str,
) -> Union[int, Sequence[int]]:
Expand Down Expand Up @@ -719,8 +726,27 @@ def get_sampler(weights):
return get_sampler(weights)


class TaskManager:
def __init__(self, mode: str, n_classes: int = None):
class TaskConfig(BaseModel):
mode: str
network_task: Task
n_classe: Optional[int] = None
df: Optional[pd.DataFrame] = None
label: Optional[str] = None

def __init__(
self,
network_task: Union[str, Task],
mode: str,
n_classes: Optional[int] = None,
df: Optional[pd.DataFrame] = None,
label: Optional[str] = None,
):
network_task = Task(network_task)
if network_task == Task.CLASSIFICATION:
if n_classes is None and df is not None:
n_classes = output_size(Task.CLASSIFICATION, None, df, label)
self.n_classes = n_classes

self.mode = mode
self.metrics_module = MetricModule(
evaluation_metrics(network_task), n_classes=n_classes
Expand Down

0 comments on commit e027c21

Please sign in to comment.