Commit 765f29a

remove split_name parameters
camillebrianceau committed Oct 3, 2024
1 parent 7382e97 commit 765f29a
Showing 18 changed files with 69 additions and 152 deletions.
4 changes: 2 additions & 2 deletions clinicadl/commandline/pipelines/predict/cli.py
@@ -3,10 +3,10 @@
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
     computational,
-    cross_validation,
     data,
     dataloader,
     maps_manager,
+    split,
     validation,
 )
 from clinicadl.commandline.pipelines.predict import options
@@ -29,7 +29,7 @@
 @data.diagnoses
 @validation.skip_leak_check
 @validation.selection_metrics
-@cross_validation.split
+@split.split
 @computational.gpu
 @computational.amp
 @dataloader.n_proc
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/train/classification/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -70,8 +70,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
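For orientation only: the training CLIs in this commit swap `cross_validation.n_splits` / `cross_validation.split` for the same options exposed by the new `split` module. Below is a minimal sketch, assuming `split.n_splits` and `split.split` are the reusable Click option decorators shown in the diffs; the command name and body are hypothetical and not part of this commit.

# Minimal sketch (not from this commit): stacking the relocated split options
# on a Click command. Only the `split` module path and its two decorators are
# taken from the diffs above; everything else here is hypothetical.
import click

from clinicadl.commandline.modules_options import split


@click.command(name="demo", no_args_is_help=True)
# Cross validation
@split.n_splits
@split.split
def demo_cli(**kwargs):
    """Hypothetical command showing the renamed option module."""
    click.echo(kwargs)


if __name__ == "__main__":
    demo_cli()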
4 changes: 2 additions & 2 deletions clinicadl/commandline/pipelines/train/from_json/cli.py
@@ -4,15 +4,15 @@
 
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
-    cross_validation,
+    split,
 )
 from clinicadl.trainer.trainer import Trainer
 
 
 @click.command(name="from_json", no_args_is_help=True)
 @arguments.config_file
 @arguments.output_maps
-@cross_validation.split
+@split.split
 def cli(**kwargs):
     """
     Replicate a deep learning training based on a previously created JSON file.
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/train/reconstruction/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -70,8 +70,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
6 changes: 3 additions & 3 deletions clinicadl/commandline/pipelines/train/regression/cli.py
@@ -4,7 +4,6 @@
 from clinicadl.commandline.modules_options import (
     callbacks,
     computational,
-    cross_validation,
     data,
     dataloader,
     early_stopping,
@@ -13,6 +12,7 @@
     optimization,
     optimizer,
     reproducibility,
+    split,
     ssda,
     transforms,
     validation,
@@ -68,8 +68,8 @@
 @ssda.tsv_target_unlab
 @ssda.preprocessing_json_target
 # Cross validation
-@cross_validation.n_splits
-@cross_validation.split
+@split.n_splits
+@split.split
 # Optimization
 @optimizer.optimizer
 @optimizer.weight_decay
4 changes: 2 additions & 2 deletions clinicadl/commandline/pipelines/train/resume/cli.py
@@ -2,14 +2,14 @@
 
 from clinicadl.commandline import arguments
 from clinicadl.commandline.modules_options import (
-    cross_validation,
+    split,
 )
 from clinicadl.trainer.trainer import Trainer
 
 
 @click.command(name="resume", no_args_is_help=True)
 @arguments.input_maps
-@cross_validation.split
+@split.split
 def cli(input_maps_directory, split):
     """Resume training job in specified maps.
47 changes: 9 additions & 38 deletions clinicadl/maps_manager/maps_manager.py
@@ -105,17 +105,11 @@ def __init__(
                 f"Please choose between classification, regression and reconstruction."
             )
 
-            self.split_name = (
-                self._check_split_wording()
-            )  # Used only for retro-compatibility
-
         # Initiate MAPS
         else:
             print(parameters)
             self._check_args(parameters)
             parameters["tsv_path"] = Path(parameters["tsv_path"])
-
-            self.split_name = "split"  # Used only for retro-compatibility
             if cluster.master:
                 if (maps_path.is_dir() and maps_path.is_file()) or (  # Non-folder file
                     maps_path.is_dir() and list(maps_path.iterdir())  # Non empty folder
@@ -326,12 +320,7 @@ def _write_train_val_groups(self):
         for split in split_manager.split_iterator():
             for data_group in ["train", "validation"]:
                 df = split_manager[split][data_group]
-                group_path = (
-                    self.maps_path
-                    / "groups"
-                    / data_group
-                    / f"{self.split_name}-{split}"
-                )
+                group_path = self.maps_path / "groups" / data_group / f"split-{split}"
                 group_path.mkdir(parents=True, exist_ok=True)
 
                 columns = ["participant_id", "session_id", "cohort"]
@@ -422,10 +411,7 @@ def _mode_level_to_tsv(
             data_group: the name referring to the data group on which evaluation is performed.
         """
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
         performance_dir.mkdir(parents=True, exist_ok=True)
         performance_path = (
@@ -482,7 +468,6 @@ def _ensemble_to_tsv(
             validation_dataset = "validation"
         test_df = get_prediction(
             self.maps_path,
-            self.split_name,
             data_group,
             split,
             selection,
@@ -491,7 +476,6 @@ def _ensemble_to_tsv(
         )
         validation_df = get_prediction(
             self.maps_path,
-            self.split_name,
             validation_dataset,
             split,
             selection,
@@ -500,10 +484,7 @@ def _ensemble_to_tsv(
         )
 
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
 
         performance_dir.mkdir(parents=True, exist_ok=True)
@@ -551,7 +532,6 @@ def _mode_to_image_tsv(
         """
         sub_df = get_prediction(
             self.maps_path,
-            self.split_name,
             data_group,
             split,
             selection,
@@ -561,10 +541,7 @@ def _mode_to_image_tsv(
         sub_df.rename(columns={f"{self.mode}_id": "image_id"}, inplace=True)
 
         performance_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection}" / data_group
         )
         sub_df.to_csv(
             performance_dir / f"{data_group}_image_level_prediction.tsv",
@@ -638,10 +615,7 @@ def _init_model(
 
         if resume:
             checkpoint_path = (
-                self.maps_path
-                / f"{self.split_name}-{split}"
-                / "tmp"
-                / "checkpoint.pth.tar"
+                self.maps_path / f"split-{split}" / "tmp" / "checkpoint.pth.tar"
             )
             checkpoint_state = torch.load(
                 checkpoint_path, map_location=device, weights_only=True
@@ -694,10 +668,7 @@ def _print_description_log(
             selection_metric (str): Metric used for best weights selection.
         """
         log_dir = (
-            self.maps_path
-            / f"{self.split_name}-{split}"
-            / f"best-{selection_metric}"
-            / data_group
+            self.maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group
         )
         log_path = log_dir / "description.log"
         with log_path.open(mode="r") as f:
@@ -767,7 +738,7 @@ def get_state_dict(
             (Dict): dictionary of results (weights, epoch number, metrics values)
         """
         selection_metric = check_selection_metric(
-            self.maps_path, self.split_name, split, selection_metric
+            self.maps_path, split, selection_metric
         )
         if self.multi_network:
             if network is None:
@@ -777,14 +748,14 @@ def get_state_dict(
             else:
                 model_path = (
                     self.maps_path
-                    / f"{self.split_name}-{split}"
+                    / f"split-{split}"
                     / f"best-{selection_metric}"
                     / f"network-{network}_model.pth.tar"
                 )
         else:
             model_path = (
                 self.maps_path
-                / f"{self.split_name}-{split}"
+                / f"split-{split}"
                 / f"best-{selection_metric}"
                 / "model.pth.tar"
             )
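The net effect in `maps_manager.py` is that the retro-compatibility attribute `self.split_name` disappears and every sub-directory is built with a fixed `split-` prefix. A small illustrative sketch of the resulting layout follows; the MAPS root, split index, selection metric and data group are made-up values, not taken from the commit.

# Illustrative only: the directory naming used after this commit.
# `maps_path`, `split`, `selection_metric` and `data_group` are assumptions.
from pathlib import Path

maps_path = Path("results/maps")  # hypothetical MAPS root
split = 0
selection_metric = "loss"
data_group = "validation"

# Before: maps_path / f"{self.split_name}-{split}" / ...
# After:  a hard-coded "split-" prefix, as in the diffs above.
group_path = maps_path / "groups" / data_group / f"split-{split}"
performance_dir = maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group
checkpoint_path = maps_path / f"split-{split}" / "tmp" / "checkpoint.pth.tar"

print(performance_dir)  # results/maps/split-0/best-loss/validation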
21 changes: 7 additions & 14 deletions clinicadl/metrics/utils.py
@@ -7,10 +7,10 @@
 from clinicadl.utils.exceptions import ClinicaDLArgumentError, MAPSError
 
 
-def find_selection_metrics(maps_path: Path, split_name: str, split):
+def find_selection_metrics(maps_path: Path, split):
     """Find which selection metrics are available in MAPS for a given split."""
 
-    split_path = maps_path / f"{split_name}-{split}"
+    split_path = maps_path / f"split-{split}"
     if not split_path.is_dir():
         raise KeyError(
             f"Training of split {split} was not performed."
@@ -24,11 +24,9 @@ def find_selection_metrics(maps_path: Path, split_name: str, split):
     ]
 
 
-def check_selection_metric(
-    maps_path: Path, split_name: str, split, selection_metric=None
-):
+def check_selection_metric(maps_path: Path, split, selection_metric=None):
     """Check that a given selection metric is available for a given split."""
-    available_metrics = find_selection_metrics(maps_path, split_name, split)
+    available_metrics = find_selection_metrics(maps_path, split)
 
     if not selection_metric:
         if len(available_metrics) > 1:
@@ -49,7 +47,6 @@ def check_selection_metric(
 
 def get_metrics(
     maps_path: Path,
-    split_name: str,
     data_group: str,
     split: int = 0,
     selection_metric: Optional[str] = None,
@@ -68,15 +65,11 @@ def get_metrics(
     Returns:
         (dict[str:float]): Values of the metrics
     """
-    selection_metric = check_selection_metric(
-        maps_path, split_name, split, selection_metric
-    )
+    selection_metric = check_selection_metric(maps_path, split, selection_metric)
     if verbose:
-        print_description_log(
-            maps_path, split_name, data_group, split, selection_metric
-        )
+        print_description_log(maps_path, data_group, split, selection_metric)
     prediction_dir = (
-        maps_path / f"{split_name}-{split}" / f"best-{selection_metric}" / data_group
+        maps_path / f"split-{split}" / f"best-{selection_metric}" / data_group
     )
     if not prediction_dir.is_dir():
         raise MAPSError(
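With `split_name` gone, the helpers in `clinicadl/metrics/utils.py` are called with one argument fewer. Below is a hedged usage sketch based only on the signatures visible in the diff; the MAPS path, metric name and data group are assumptions.

# Hedged sketch of the updated signatures (no `split_name` parameter).
# The MAPS directory, metric name and data group below are hypothetical.
from pathlib import Path

from clinicadl.metrics.utils import (
    check_selection_metric,
    find_selection_metrics,
    get_metrics,
)

maps_path = Path("results/maps")  # hypothetical

# Selection metrics available for split 0 of this MAPS.
available = find_selection_metrics(maps_path, split=0)

# Validate (or infer, when None) the metric used to pick the best weights.
metric = check_selection_metric(maps_path, split=0, selection_metric="loss")

# Metric values computed for a data group on that split.
values = get_metrics(maps_path, data_group="test", split=0, selection_metric=metric)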