Commit d14b696

tests

camillebrianceau committed Jul 22, 2024
1 parent b1df7f9 commit d14b696
Showing 4 changed files with 77 additions and 63 deletions.
3 changes: 3 additions & 0 deletions clinicadl/caps_dataset/caps_dataset_config.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import pandas as pd
from pydantic import BaseModel, ConfigDict

from clinicadl.caps_dataset.data_config import DataConfig
@@ -136,6 +137,8 @@ def from_data_group(
config.data.caps_directory = caps_directory
config.data.data_tsv = data_tsv

config.data.data_df = pd.read_csv(config.data.data_tsv, sep="\t")

return config

@classmethod
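The added lines make `from_data_group` load the data TSV into a pandas DataFrame as soon as the config is built, instead of leaving `data_df` unset. A minimal sketch of that read, assuming a ClinicaDL-style tab-separated participants list (the file path and column names below are illustrative, not taken from the commit):

    from pathlib import Path

    import pandas as pd

    data_tsv = Path("splits/test_baseline.tsv")  # hypothetical TSV path
    # ClinicaDL data lists are tab-separated, hence sep="\t"
    data_df = pd.read_csv(data_tsv, sep="\t")
    # Typical columns are participant_id and session_id (plus an optional label)
    print(data_df.head())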
4 changes: 3 additions & 1 deletion clinicadl/commandline/pipelines/predict/cli.py
@@ -61,8 +61,10 @@ def cli(input_maps_directory, data_group, **kwargs):
INPUT_MAPS_DIRECTORY is the MAPS folder from where the model used for prediction will be loaded.
DATA_GROUP is the name of the subjects and sessions list used for the interpretation.
"""
predictor = Predictor.from_maps(input_maps_directory)
print(kwargs["gpu"])
predictor = Predictor.from_maps(input_maps_directory, **kwargs)
print(predictor)

caps_config = CapsDatasetConfig.from_data_group(
input_maps_directory, data_group, **kwargs
)
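The predict CLI now forwards its keyword arguments to `Predictor.from_maps`, so command-line options can override what is stored in the MAPS. A rough usage sketch, assuming the class is importable from the module edited below and that `gpu` is one of the forwarded options (both assumptions; only the call pattern mirrors the diff):

    from pathlib import Path

    from clinicadl.predict.predict_manager import Predictor  # assumed import path

    maps_dir = Path("results/maps_classification")  # hypothetical MAPS directory
    # Keyword arguments are merged into the configuration read from maps.json
    # (see config_dict.update(kwargs) in from_json below), so gpu=False takes
    # precedence over the value stored at training time.
    predictor = Predictor.from_maps(maps_dir, gpu=False)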
127 changes: 65 additions & 62 deletions clinicadl/predict/predict_manager.py
@@ -55,7 +55,9 @@ def __init__(self, _config: TrainConfig) -> None:
self.maps_manager = MapsManager(maps_path, parameters, verbose=None)

@classmethod
def from_json(cls, config_file: Union[str, Path], maps_path: Union[str, Path]):
def from_json(
cls, config_file: Union[str, Path], maps_path: Union[str, Path], **kwargs
):
"""
Creates a Trainer from a json configuration file.
@@ -82,13 +84,14 @@ def from_json(cls, config_file: Union[str, Path], maps_path: Union[str, Path]):
raise FileNotFoundError(f"No file found at {str(config_file)}.")
config_dict = patch_to_read_json(read_json(config_file)) # TODO : remove patch
config_dict["maps_dir"] = maps_path
config_dict.update(kwargs)
config_object = create_training_config(config_dict["network_task"])(
**config_dict
)
return cls(config_object)

@classmethod
def from_maps(cls, maps_path: Union[str, Path]):
def from_maps(cls, maps_path: Union[str, Path], **kwargs):
"""
Creates a Trainer from a json configuration file.
@@ -114,7 +117,7 @@ def from_maps(cls, maps_path: Union[str, Path]):
f"MAPS was not found at {str(maps_path)}."
f"To initiate a new MAPS please give a train_dict."
)
return cls.from_json(maps_path / "maps.json", maps_path)
return cls.from_json(maps_path / "maps.json", maps_path, **kwargs)

def predict(
self,
@@ -253,19 +256,22 @@ def predict(
)

self._check_data_group(data_group, caps_config)
print(caps_config)
criterion = self.maps_manager.task_manager.get_criterion(self.maps_manager.loss)
self._check_data_group(df=group_df)
# self._check_data_group(df=group_df)

assert self.config.split # don't know if needed ? try to raise an exception ?
assert (
self.config.cross_validation.split
) # don't know if needed ? try to raise an exception ?
# assert self.config.data.label

for split in self.config.split:
for split in self.config.cross_validation.split:
logger.info(f"Prediction of split {split}")
group_df, group_parameters = self.get_group_info(
self.config.data.data_group, split
)
group_df, group_parameters = self.get_group_info(data_group, split)
# Find label code if not given
if self.config.is_given_label_code(self.maps_manager.label, label_code):
if self.config.data.is_given_label_code(
self.maps_manager.label, label_code
):
self.maps_manager.task_manager.generate_label_code(
group_df, self.config.data.label
)
@@ -281,12 +287,12 @@
self.maps_manager.maps_path
/ f"{self.maps_manager.split_name}-{split}"
/ f"best-{selection}"
/ self.config.data.data_group
/ data_group
)
tsv_pattern = f"{self.config.data.data_group}*.tsv"
tsv_pattern = f"{data_group}*.tsv"
for tsv_file in tsv_dir.glob(tsv_pattern):
tsv_file.unlink()
self.config.check_label(self.maps_manager.label)
self.config.data.check_label(self.maps_manager.label)
if self.maps_manager.multi_network:
self._predict_multi(
group_parameters,
@@ -309,11 +315,11 @@
)
if cluster.master:
self.maps_manager._ensemble_prediction(
self.config.data.data_group,
data_group,
split,
self.config.validation.selection_metrics,
self.config.data.use_labels,
self.config.skip_leak_check,
self.config.validation.skip_leak_check,
)

def _predict_multi(
@@ -954,60 +960,57 @@ def _check_data_group(
group_dir = self.maps_manager.maps_path / "groups" / data_group
logger.debug(f"Group path {group_dir}")
if group_dir.is_dir(): # Data group already exists
if self.config.maps_manager.overwrite:
if data_group in ["train", "validation"]:
raise MAPSError("Cannot overwrite train or validation data group.")
else:
print(self.config.cross_validation.split)
if not self.config.cross_validation.split:
self.config.cross_validation.split = (
self.maps_manager.find_splits()
)

print(self.config.cross_validation.split)
# assert self.config.split
for split in self.config.cross_validation.split:
selection_metrics = self.maps_manager._find_selection_metrics(
split
)
for selection in selection_metrics:
results_path = (
self.maps_manager.maps_path
/ f"{self.maps_manager.split_name}-{split}"
/ f"best-{selection}"
/ data_group
)
if results_path.is_dir():
shutil.rmtree(results_path)
elif df is not None or (
caps_config.caps_directory is not None
and self.config.caps_directory != Path("")
):
raise ClinicaDLArgumentError(
f"Data group {data_group} is already defined. "
f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. "
f"To erase {data_group} please set overwrite to True."
)

elif not group_dir.is_dir() and (
self.config.caps_directory is None or df is None
): # Data group does not exist yet / was overwritten + missing data
raise ClinicaDLArgumentError(
f"The data group {self.config.data.data_group} does not already exist. "
f"Please specify a caps_directory and a tsv_path to create this data group."
)
# if self.config.maps_manager.overwrite:
# if data_group in ["train", "validation"]:
# raise MAPSError("Cannot overwrite train or validation data group.")
# else:
print("cross validation_split", self.config.cross_validation.split)
if not self.config.cross_validation.split:
self.config.cross_validation.split = self.maps_manager.find_splits()
print("cross validation_split", self.config.cross_validation.split)
# assert self.config.split

# IF there is a dir
for split in self.config.cross_validation.split:
selection_metrics = self.maps_manager._find_selection_metrics(split)
for selection in selection_metrics:
results_path = (
self.maps_manager.maps_path
/ f"{self.maps_manager.split_name}-{split}"
/ f"best-{selection}"
/ data_group
)
if results_path.is_dir() and self.config.maps_manager.overwrite:
shutil.rmtree(results_path)
# elif df is not None or (
# caps_config.caps_directory is not None
# and self.config.caps_directory != Path("")
# ):
# raise ClinicaDLArgumentError(
# f"Data group {data_group} is already defined. "
# f"Please do not give any caps_directory, tsv_path or multi_cohort to use it. "
# f"To erase {data_group} please set overwrite to True."
# )

# elif not group_dir.is_dir() and (
# self.config.caps_directory is None or df is None
# ): # Data group does not exist yet / was overwritten + missing data
# raise ClinicaDLArgumentError(
# f"The data group {self.config.data.data_group} does not already exist. "
# f"Please specify a caps_directory and a tsv_path to create this data group."
# )
elif (
not group_dir.is_dir()
): # Data group does not exist yet / was overwritten + all data is provided
if self.config.skip_leak_check:
if self.config.validation.skip_leak_check:
logger.info("Skipping data leakage check")
else:
self._check_leakage(self.config.data.data_group, df)
self._check_leakage(data_group, caps_config.data.data_df)
self._write_data_group(
self.config.data.data_group,
data_group,
df,
self.config.caps_directory,
self.config.multi_cohort,
caps_config.data.caps_directory,
caps_config.data.multi_cohort,
label=self.config.data.label,
)

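The key change in `from_json` above is that caller-supplied keyword arguments are merged into the dictionary read from maps.json before the training config is instantiated, so explicit arguments win over stored values. A stripped-down sketch of that merge pattern (file names and keys are illustrative; the real code goes through read_json, patch_to_read_json, and create_training_config):

    import json
    from pathlib import Path


    def load_config_dict(config_file: Path, maps_path: Path, **kwargs) -> dict:
        # Start from the parameters stored with the MAPS
        config_dict = json.loads(config_file.read_text())
        config_dict["maps_dir"] = str(maps_path)
        # Later keys win: anything passed explicitly overrides the stored value
        config_dict.update(kwargs)
        return config_dict


    stored = Path("maps/maps.json")  # hypothetical MAPS json
    params = load_config_dict(stored, Path("maps"), gpu=False, batch_size=8)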
6 changes: 6 additions & 0 deletions clinicadl/trainer/config/reconstruction.py
@@ -33,6 +33,12 @@ class NetworkConfig(BaseNetworkConfig):  # TODO : put in model module
def validator_architecture(cls, v):
return v # TODO : connect to network module to have list of available architectures

@field_validator("normalization", mode="before")
def validator_normalization(cls, v):
if v == "batch":
v = "BatchNorm"
return v


class ValidationConfig(BaseValidationConfig):
"""Config class for the validation procedure in reconstruction mode."""
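The added validator lets users keep passing the short name "batch" while the rest of the code expects "BatchNorm". A self-contained sketch of the same pre-validation mapping with pydantic v2 (the field and class names here are illustrative, not the actual ClinicaDL config):

    from pydantic import BaseModel, field_validator


    class NormalizationConfig(BaseModel):
        normalization: str = "BatchNorm"

        @field_validator("normalization", mode="before")
        @classmethod
        def validator_normalization(cls, v):
            # Accept the CLI short name and map it to the layer name
            if v == "batch":
                v = "BatchNorm"
            return v


    print(NormalizationConfig(normalization="batch").normalization)  # BatchNorm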
