MapsManager cleaning #646

Merged
197 changes: 35 additions & 162 deletions clinicadl/maps_manager/maps_manager.py
@@ -14,6 +14,11 @@
from clinicadl.caps_dataset.data import (
return_dataset,
)
from clinicadl.metrics.utils import (
check_selection_metric,
find_selection_metrics,
)
from clinicadl.predict.utils import get_prediction
from clinicadl.transforms.config import TransformsConfig
from clinicadl.utils import cluster
from clinicadl.utils.computational.ddp import DDP, init_ddp
@@ -332,44 +337,6 @@ def _compute_output_tensors(
torch.save(output, tensor_path / output_filename)
logger.debug(f"File saved at {[input_filename, output_filename]}")

def find_splits(self) -> List[int]:
"""Find which splits that were trained in the MAPS."""
splits = [
int(split.name.split("-")[1])
for split in list(self.maps_path.iterdir())
if split.name.startswith(f"{self.split_name}-")
]
return splits

def find_stopped_splits(self) -> List[int]:
"""Find which splits for which training was not completed."""
existing_split_list = self.find_splits()
stopped_splits = [
split
for split in existing_split_list
if (self.maps_path / f"{self.split_name}-{split}" / "tmp")
in list((self.maps_path / f"{self.split_name}-{split}").iterdir())
]
return stopped_splits

def find_finished_splits(self) -> List[int]:
"""Find which splits for which training was completed."""
finished_splits = list()
existing_split_list = self.find_splits()
stopped_splits = self.find_stopped_splits()
for split in existing_split_list:
if split not in stopped_splits:
performance_dir_list = [
performance_dir
for performance_dir in list(
(self.maps_path / f"{self.split_name}-{split}").iterdir()
)
if "best-" in performance_dir.name
]
if len(performance_dir_list) > 0:
finished_splits.append(split)
return finished_splits

def _ensemble_prediction(
self,
data_group,
@@ -381,7 +348,9 @@ def _ensemble_prediction(
"""Computes the results on the image-level."""

if not selection_metrics:
selection_metrics = self._find_selection_metrics(split)
selection_metrics = find_selection_metrics(
self.maps_path, self.split_name, split
)

for selection_metric in selection_metrics:
#####################
@@ -495,42 +464,6 @@ def _check_split_wording(self):
else:
return "split"

def _find_selection_metrics(self, split):
"""Find which selection metrics are available in MAPS for a given split."""

split_path = self.maps_path / f"{self.split_name}-{split}"
if not split_path.is_dir():
raise MAPSError(
f"Training of split {split} was not performed."
f"Please execute maps_manager.train(split_list=[{split}])"
)

return [
metric.name.split("-")[1]
for metric in list(split_path.iterdir())
if metric.name[:5:] == "best-"
]

def _check_selection_metric(self, split, selection_metric=None):
"""Check that a given selection metric is available for a given split."""
available_metrics = self._find_selection_metrics(split)

if not selection_metric:
if len(available_metrics) > 1:
raise ClinicaDLArgumentError(
f"Several metrics are available for split {split}. "
f"Please choose which one you want to read among {available_metrics}"
)
else:
selection_metric = available_metrics[0]
else:
if selection_metric not in available_metrics:
raise ClinicaDLArgumentError(
f"The metric {selection_metric} is not available."
f"Please choose among is the available metrics {available_metrics}."
)
return selection_metric

###############################
# File writers #
###############################
@@ -747,11 +680,23 @@ def _ensemble_to_tsv(
validation_dataset = data_group
else:
validation_dataset = "validation"
test_df = self.get_prediction(
data_group, split, selection, self.mode, verbose=False
test_df = get_prediction(
self.maps_path,
self.split_name,
data_group,
split,
selection,
self.mode,
verbose=False,
)
validation_df = self.get_prediction(
validation_dataset, split, selection, self.mode, verbose=False
validation_df = get_prediction(
self.maps_path,
self.split_name,
validation_dataset,
split,
selection,
self.mode,
verbose=False,
)

performance_dir = (
@@ -800,8 +745,14 @@ def _mode_to_image_tsv(
use_labels: If True the labels are added to the final tsv

"""
sub_df = self.get_prediction(
data_group, split, selection, self.mode, verbose=False
sub_df = get_prediction(
self.maps_path,
self.split_name,
data_group,
split,
selection,
self.mode,
verbose=False,
)
sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True)

@@ -1075,7 +1026,9 @@ def get_state_dict(
-------
(Dict): dictionary of results (weights, epoch number, metrics values)
"""
selection_metric = self._check_selection_metric(split, selection_metric)
selection_metric = check_selection_metric(
self.maps_path, self.split_name, split, selection_metric
)
if self.multi_network:
if network is None:
raise ClinicaDLArgumentError(
@@ -1103,86 +1056,6 @@
)
return torch.load(model_path, map_location=map_location)

def get_prediction(
self,
data_group: str,
split: int = 0,
selection_metric: Optional[str] = None,
mode: str = "image",
verbose: bool = False,
):
"""
Get the individual predictions for each participant corresponding to one group
of participants identified by its data group.

Args:
data_group (str): name of the data group used for the prediction task.
split (int): Index of the split used for training.
selection_metric (str): Metric used for best weights selection.
mode (str): level of the prediction.
verbose (bool): if True will print associated prediction.log.
Returns:
(DataFrame): Results indexed by columns 'participant_id' and 'session_id' which
identifies the image in the BIDS / CAPS.
"""
selection_metric = self._check_selection_metric(split, selection_metric)
if verbose:
self._print_description_log(data_group, split, selection_metric)
prediction_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
)
df = pd.read_csv(
prediction_dir / f"{data_group}_{mode}_level_prediction.tsv",
sep="\t",
)
df.set_index(["participant_id", "session_id"], inplace=True, drop=True)
return df

def get_metrics(
self,
data_group: str,
split: int = 0,
selection_metric: Optional[str] = None,
mode: str = "image",
verbose: bool = True,
):
"""
Get the metrics corresponding to a group of participants identified by its data_group.

Args:
data_group (str): name of the data group used for the prediction task.
split (int): Index of the split used for training.
selection_metric (str): Metric used for best weights selection.
mode (str): level of the prediction
verbose (bool): if True will print associated prediction.log
Returns:
(dict[str:float]): Values of the metrics
"""
selection_metric = self._check_selection_metric(split, selection_metric)
if verbose:
self._print_description_log(data_group, split, selection_metric)
prediction_dir = (
self.maps_path
/ f"{self.split_name}-{split}"
/ f"best-{selection_metric}"
/ data_group
)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
)
df = pd.read_csv(
prediction_dir / f"{data_group}_{mode}_level_metrics.tsv", sep="\t"
)
return df.to_dict("records")[0]

@property
def std_amp(self) -> bool:
"""
File renamed without changes.
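The recurring pattern in this diff: former `MapsManager` methods become module-level helpers that take `maps_path` and `split_name` explicitly. A minimal before/after sketch of a call site (the MAPS path and the `"loss"` metric below are hypothetical, not taken from this PR):

```python
from pathlib import Path

from clinicadl.metrics.utils import check_selection_metric
from clinicadl.predict.utils import get_prediction

maps_path = Path("results/maps")  # hypothetical MAPS folder
split_name = "split"

# Before: bound to a MapsManager instance
# metric = manager._check_selection_metric(split=0, selection_metric="loss")
# df = manager.get_prediction("test", split=0, selection_metric=metric)

# After: plain functions receiving the MAPS location explicitly,
# matching the call sites updated in this diff
metric = check_selection_metric(maps_path, split_name, 0, "loss")
df = get_prediction(maps_path, split_name, "test", 0, metric, "image", verbose=False)
```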
88 changes: 88 additions & 0 deletions clinicadl/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pathlib import Path
from typing import List, Optional

import pandas as pd

from clinicadl.splitter.split_utils import print_description_log
from clinicadl.utils.exceptions import ClinicaDLArgumentError, MAPSError


def find_selection_metrics(maps_path: Path, split_name: str, split):
"""Find which selection metrics are available in MAPS for a given split."""

split_path = maps_path / f"{split_name}-{split}"
if not split_path.is_dir():
        raise MAPSError(
            f"Training of split {split} was not performed. "
            f"Please execute maps_manager.train(split_list=[{split}])."
        )

return [
metric.name.split("-")[1]
for metric in list(split_path.iterdir())
if metric.name.startswith("best-")
]


def check_selection_metric(
maps_path: Path, split_name: str, split, selection_metric=None
):
"""Check that a given selection metric is available for a given split."""
available_metrics = find_selection_metrics(maps_path, split_name, split)

if not selection_metric:
if len(available_metrics) > 1:
raise ClinicaDLArgumentError(
f"Several metrics are available for split {split}. "
f"Please choose which one you want to read among {available_metrics}"
)
else:
selection_metric = available_metrics[0]
else:
if selection_metric not in available_metrics:
raise ClinicaDLArgumentError(
f"The metric {selection_metric} is not available."
f"Please choose among is the available metrics {available_metrics}."
)
return selection_metric


def get_metrics(
maps_path: Path,
split_name: str,
data_group: str,
split: int = 0,
selection_metric: Optional[str] = None,
mode: str = "image",
verbose: bool = True,
):
"""
Get the metrics corresponding to a group of participants identified by its data_group.

    Args:
        maps_path (Path): path to the MAPS folder.
        split_name (str): name of the split level in the MAPS folder structure.
        data_group (str): name of the data group used for the prediction task.
        split (int): index of the split used for training.
        selection_metric (str): metric used for best weights selection.
        mode (str): level of the prediction.
        verbose (bool): if True, prints the associated prediction.log.
    Returns:
        (dict[str:float]): values of the metrics.
    """
selection_metric = check_selection_metric(
maps_path, split_name, split, selection_metric
)
if verbose:
print_description_log(
maps_path, split_name, data_group, split, selection_metric
)
prediction_dir = (
maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group
)
if not prediction_dir.is_dir():
raise MAPSError(
f"No prediction corresponding to data group {data_group} was found."
)
df = pd.read_csv(
prediction_dir / f"{data_group}_{mode}_level_metrics.tsv", sep="\t"
)
return df.to_dict("records")[0]