MapsManager cleaning (#646)
camillebrianceau authored Aug 29, 2024
1 parent 66524a3 commit 5487560
Showing 11 changed files with 307 additions and 182 deletions.
197 changes: 35 additions & 162 deletions clinicadl/maps_manager/maps_manager.py
@@ -14,6 +14,11 @@
from clinicadl.caps_dataset.data import (
return_dataset,
)
from clinicadl.metrics.utils import (
check_selection_metric,
find_selection_metrics,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils import cluster
from clinicadl.utils.computational.ddp import DDP, init_ddp
@@ -332,44 +337,6 @@ def _compute_output_tensors(
torch.save(output, tensor_path / output_filename)
logger.debug(f"File saved at {[input_filename, output_filename]}")

def find_splits(self) -> List[int]:
"""Find which splits that were trained in the MAPS."""
splits = [
int(split.name.split("-")[1])
for split in list(self.maps_path.iterdir())
if split.name.startswith(f"{self.split_name}-")
]
return splits

def find_stopped_splits(self) -> List[int]:
"""Find which splits for which training was not completed."""
existing_split_list = self.find_splits()
stopped_splits = [
split
for split in existing_split_list
if (self.maps_path / f"{self.split_name}-{split}" / "tmp")
in list((self.maps_path / f"{self.split_name}-{split}").iterdir())
]
return stopped_splits

def find_finished_splits(self) -> List[int]:
"""Find which splits for which training was completed."""
finished_splits = list()
existing_split_list = self.find_splits()
stopped_splits = self.find_stopped_splits()
for split in existing_split_list:
if split not in stopped_splits:
performance_dir_list = [
performance_dir
for performance_dir in list(
(self.maps_path / f"{self.split_name}-{split}").iterdir()
)
if "best-" in performance_dir.name
]
if len(performance_dir_list) > 0:
finished_splits.append(split)
return finished_splits

def _ensemble_prediction(
self,
data_group,
@@ -381,7 +348,9 @@ def _ensemble_prediction(
"""Computes the results on the image-level."""

if not selection_metrics:
-selection_metrics = self._find_selection_metrics(split)
+selection_metrics = find_selection_metrics(
+    self.maps_path, self.split_name, split
+)

for selection_metric in selection_metrics:
#####################
@@ -495,42 +464,6 @@ def _check_split_wording(self):
else:
return "split"

def _find_selection_metrics(self, split):
"""Find which selection metrics are available in MAPS for a given split."""

split_path = self.maps_path / f"{self.split_name}-{split}"
if not split_path.is_dir():
raise MAPSError(
f"Training of split {split} was not performed."
f"Please execute maps_manager.train(split_list=[{split}])"
)

return [
metric.name.split("-")[1]
for metric in list(split_path.iterdir())
if metric.name[:5:] == "best-"
]

def _check_selection_metric(self, split, selection_metric=None):
"""Check that a given selection metric is available for a given split."""
available_metrics = self._find_selection_metrics(split)

if not selection_metric:
if len(available_metrics) > 1:
raise ClinicaDLArgumentError(
f"Several metrics are available for split {split}. "
f"Please choose which one you want to read among {available_metrics}"
)
else:
selection_metric = available_metrics[0]
else:
if selection_metric not in available_metrics:
raise ClinicaDLArgumentError(
f"The metric {selection_metric} is not available."
f"Please choose among is the available metrics {available_metrics}."
)
return selection_metric

###############################
# File writers #
###############################
@@ -747,11 +680,23 @@ def _ensemble_to_tsv(
validation_dataset = data_group
else:
validation_dataset = "validation"
-test_df = self.get_prediction(
-    data_group, split, selection, self.mode, verbose=False
+test_df = get_prediction(
+    self.maps_path,
+    self.split_name,
+    data_group,
+    split,
+    selection,
+    self.mode,
+    verbose=False,
)
-validation_df = self.get_prediction(
-    validation_dataset, split, selection, self.mode, verbose=False
+validation_df = get_prediction(
+    self.maps_path,
+    self.split_name,
+    validation_dataset,
+    split,
+    selection,
+    self.mode,
+    verbose=False,
)

performance_dir = (
@@ -800,8 +745,14 @@ def _mode_to_image_tsv(
use_labels: If True the labels are added to the final tsv
"""
-sub_df = self.get_prediction(
-    data_group, split, selection, self.mode, verbose=False
+sub_df = get_prediction(
+    self.maps_path,
+    self.split_name,
+    data_group,
+    split,
+    selection,
+    self.mode,
+    verbose=False,
)
sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True)

@@ -1075,7 +1026,9 @@ def get_state_dict(
-------
(Dict): dictionary of results (weights, epoch number, metrics values)
"""
-selection_metric = self._check_selection_metric(split, selection_metric)
+selection_metric = check_selection_metric(
+    self.maps_path, self.split_name, split, selection_metric
+)
if self.multi_network:
if network is None:
raise ClinicaDLArgumentError(
@@ -1103,86 +1056,6 @@
)
return torch.load(model_path, map_location=map_location)

def get_prediction(
self,
data_group: str,
split: int = 0,
selection_metric: Optional[str] = None,
mode: str = "image",
verbose: bool = False,
):
"""
Get the individual predictions for each participant corresponding to one group
of participants identified by its data group.
Args:
data_group (str): name of the data group used for the prediction task.
split (int): Index of the split used for training.
selection_metric (str): Metric used for best weights selection.
mode (str): level of the prediction.
verbose (bool): if True will print associated prediction.log.
Returns:
(DataFrame): Results indexed by columns 'participant_id' and 'session_id' which
identifies the image in the BIDS / CAPS.
"""
selection_metric = self._check_selection_metric(split, selection_metric)
if verbose:
self._print_description_log(data_group, split, selection_metric)
prediction_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
)
df = pd.read_csv(
prediction_dir / f"{data_group}_{mode}_level_prediction.tsv",
sep="\t",
)
df.set_index(["participant_id", "session_id"], inplace=True, drop=True)
return df

def get_metrics(
self,
data_group: str,
split: int = 0,
selection_metric: Optional[str] = None,
mode: str = "image",
verbose: bool = True,
):
"""
Get the metrics corresponding to a group of participants identified by its data_group.
Args:
data_group (str): name of the data group used for the prediction task.
split (int): Index of the split used for training.
selection_metric (str): Metric used for best weights selection.
mode (str): level of the prediction
verbose (bool): if True will print associated prediction.log
Returns:
(dict[str:float]): Values of the metrics
"""
selection_metric = self._check_selection_metric(split, selection_metric)
if verbose:
self._print_description_log(data_group, split, selection_metric)
prediction_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
)
df = pd.read_csv(
prediction_dir / f"{data_group}_{mode}_level_metrics.tsv", sep="\t"
)
return df.to_dict("records")[0]

@property
def std_amp(self) -> bool:
"""
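With this commit, get_prediction is a module-level helper in clinicadl.predict.utils (imported at the top of maps_manager.py) that takes the MAPS path and split name explicitly. Below is a minimal usage sketch mirroring the refactored call sites above; the MAPS path, data group and selection metric are illustrative assumptions, not values from this commit.

from pathlib import Path

from clinicadl.predict.utils import get_prediction

# Illustrative values: point maps_path at an existing MAPS directory.
maps_path = Path("results/maps_classification")
split_name = "split"  # MAPS created with older ClinicaDL versions may use "fold"

# Same positional order as the refactored call sites:
# (maps_path, split_name, data_group, split, selection_metric, mode).
test_df = get_prediction(
    maps_path,
    split_name,
    "test",    # data group (assumed to exist in this MAPS)
    0,         # split index
    "loss",    # selection metric (assumed available for this split)
    "image",   # prediction level
    verbose=False,
)
print(test_df.head())  # expected to be indexed by participant_id / session_id, as in the former method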
File renamed without changes.
88 changes: 88 additions & 0 deletions clinicadl/metrics/utils.py
@@ -0,0 +1,88 @@
from pathlib import Path
from typing import List, Optional

import pandas as pd

from clinicadl.splitter.split_utils import print_description_log
from clinicadl.utils.exceptions import ClinicaDLArgumentError, MAPSError


def find_selection_metrics(maps_path: Path, split_name: str, split):
    """Find which selection metrics are available in MAPS for a given split."""

    split_path = maps_path / f"{split_name}-{split}"
    if not split_path.is_dir():
        raise KeyError(
            f"Training of split {split} was not performed. "
            f"Please execute maps_manager.train(split_list=[{split}])"
        )

    return [
        metric.name.split("-")[1]
        for metric in list(split_path.iterdir())
        if metric.name.startswith("best-")
    ]


def check_selection_metric(
    maps_path: Path, split_name: str, split, selection_metric=None
):
    """Check that a given selection metric is available for a given split."""
    available_metrics = find_selection_metrics(maps_path, split_name, split)

    if not selection_metric:
        if len(available_metrics) > 1:
            raise ClinicaDLArgumentError(
                f"Several metrics are available for split {split}. "
                f"Please choose which one you want to read among {available_metrics}"
            )
        else:
            selection_metric = available_metrics[0]
    else:
        if selection_metric not in available_metrics:
            raise ClinicaDLArgumentError(
                f"The metric {selection_metric} is not available. "
                f"Please choose among the available metrics {available_metrics}."
            )
    return selection_metric


def get_metrics(
    maps_path: Path,
    split_name: str,
    data_group: str,
    split: int = 0,
    selection_metric: Optional[str] = None,
    mode: str = "image",
    verbose: bool = True,
):
    """
    Get the metrics corresponding to a group of participants identified by its data_group.

    Args:
        data_group (str): name of the data group used for the prediction task.
        split (int): index of the split used for training.
        selection_metric (str): metric used for best weights selection.
        mode (str): level of the prediction.
        verbose (bool): if True, prints the associated prediction.log.
    Returns:
        (dict[str:float]): values of the metrics.
    """
    selection_metric = check_selection_metric(
        maps_path, split_name, split, selection_metric
    )
    if verbose:
        print_description_log(
            maps_path, split_name, data_group, split, selection_metric
        )
    prediction_dir = (
        maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group
    )
    if not prediction_dir.is_dir():
        raise MAPSError(
            f"No prediction corresponding to data group {data_group} was found."
        )
    df = pd.read_csv(
        prediction_dir / f"{data_group}_{mode}_level_metrics.tsv", sep="\t"
    )
    return df.to_dict("records")[0]
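A short usage sketch of the new clinicadl.metrics.utils helpers defined above. All three take maps_path and split_name explicitly, so they can be called without a MapsManager instance; the MAPS path, data group and split index below are illustrative assumptions.

from pathlib import Path

from clinicadl.metrics.utils import (
    check_selection_metric,
    find_selection_metrics,
    get_metrics,
)

maps_path = Path("results/maps_classification")  # illustrative MAPS directory
split_name = "split"

# List the "best-<metric>" folders produced during training of split 0.
available = find_selection_metrics(maps_path, split_name, 0)

# Resolve which metric to read; raises ClinicaDLArgumentError if the choice is
# ambiguous (several available, none given) or unknown.
metric = check_selection_metric(maps_path, split_name, 0, selection_metric=None)

# Load <data_group>_<mode>_level_metrics.tsv as a dict of metric values.
test_metrics = get_metrics(
    maps_path,
    split_name,
    "test",                   # data group (assumed to exist in this MAPS)
    split=0,
    selection_metric=metric,
    mode="image",
    verbose=False,            # skip printing the prediction log
)
print(available, test_metrics)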