
Commit

Cleaning (#689)
camillebrianceau authored Dec 18, 2024
1 parent 4d40ba6 commit dc3443e
Showing 129 changed files with 936 additions and 2,870 deletions.
88 changes: 25 additions & 63 deletions clinicadl/API/complicated_case.py
@@ -2,41 +2,26 @@

import torchio.transforms as transforms

from clinicadl.dataset.dataloader_config import DataLoaderConfig
from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
PreprocessingCustom,
from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.data.datasets.caps_dataset import CapsDataset
from clinicadl.data.datasets.concat import ConcatDataset
from clinicadl.data.preprocessing import (
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.dataset.readers.caps_reader import CapsReader
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.losses.factory import get_loss_function
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config import ImplementedNetworks
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig
from clinicadl.optimization.optimizer.factory import get_optimizer
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.networks.config.resnet import ResNetConfig
from clinicadl.optim.optimizers.config import AdamConfig
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.trainer.trainer import Trainer
from clinicadl.transforms.extraction import ROI, BaseExtraction, Image, Patch, Slice
from clinicadl.transforms.extraction import Extraction, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms

# Create the Maps Manager / Read/write manager /
maps_path = Path("/")
manager = ExperimentManager(
maps_path, overwrite=False
) # to add to the manager: mlflow / profiler / etc.

caps_directory = Path("caps_directory") # output of clinica pipelines
caps_directory = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps"
) # output of clinica pipelines

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
preprocessing_t1 = PreprocessingT1()
@@ -60,7 +45,7 @@
sub_ses_pet_45 = Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_pet_18FAV45.tsv"
)
preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2") # type: ignore

dataset_pet_image = CapsDataset(
caps_directory=caps_directory,
@@ -79,47 +64,27 @@
) # 3 train.tsv files as input that must be concatenated; the same goes for the transforms, so be careful


config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# CROSS-VALIDATION CASE
splitter = KFolder(caps_dataset=dataset_multi_modality_multi_extract, manager=manager)
split_dir = splitter.make_splits(
n_splits=3, output_dir=Path(""), subset_name="validation", stratification=""
) # Optional data tsv and output_dir

dataloader_config = DataLoaderConfig(n_procs=3, batch_size=10)
split_dir = make_split(sub_ses_t1, n_test=0.2) # Optional data tsv and output_dir
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

splitter = KFold(fold_dir)

# CASE 1

# Prerequisite: already have files with the train and validation lists
split_dir = make_kfold(
"dataset.tsv"
) # reads dataset.tsv => performs the k-fold => writes the output to split_dir
splitter = KFolder(
dataset_multi_modality, split_dir
) # this is more like an iterable of dataloaders
maps_path = Path("/")
manager = ExperimentManager(maps_path, overwrite=False)

# CASE 2
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
splitter.write(split_dir)
config_file = Path("config_file")
trainer = Trainer.from_json(config_file=config_file, manager=manager)

# or
splitter = KFolder(caps_dataset=dataset_t1_image)
splitter.read(split_dir)

for split in splitter.get_splits(splits=(0, 3, 4), dataloader_config=dataloader_config):
# clearly define what the split object contains
for split in splitter.get_splits(dataset=dataset_t1_image):
train_loader = split.build_train_loader(batch_size=2)
val_loader = split.build_val_loader(DataLoaderConfig())
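The new lines above replace the KFolder workflow with two free functions plus a lightweight KFold splitter. Below is a minimal end-to-end sketch of that flow, assuming only the signatures visible in this diff; the TSV path and the split sizes are illustrative:

from pathlib import Path

from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.splitter import KFold, make_kfold, make_split

# make_split partitions a subject/session TSV into train and test lists (illustrative path)
split_dir = make_split(Path("subjects_t1.tsv"), n_test=0.2)
# make_kfold then folds the training subjects written by make_split
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

# KFold now only reads fold definitions; the dataset is supplied at iteration time
splitter = KFold(fold_dir)
for split in splitter.get_splits(dataset=dataset_t1_image):  # dataset defined earlier in the file
    train_loader = split.build_train_loader(batch_size=2)
    val_loader = split.build_val_loader(DataLoaderConfig(n_procs=3, batch_size=10))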

network_config = create_network_config(ImplementedNetworks.CNN)(
in_shape=[2, 2, 2],
num_outputs=1,
conv_args=ConvEncoderOptions(channels=[3, 2, 2]),
)
model = ClinicaDLModelClassif.from_config(
network_config=network_config,
model = ClinicaDLModel.from_config(
network_config=ResNetConfig(num_outputs=1, spatial_dims=1, in_channels=1),
loss_config=CrossEntropyLossConfig(),
optimizer_config=AdamConfig(),
)
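The create_network_config / ConvEncoderOptions factory is gone; the model is now assembled directly from configuration objects. A hedged sketch of the new pattern, using only classes that appear in this commit (the ResNet hyperparameters are illustrative):

from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.config.resnet import ResNetConfig
from clinicadl.optim.optimizers.config import AdamConfig

# one config object per concern: network, loss, optimizer
model = ClinicaDLModel.from_config(
    network_config=ResNetConfig(num_outputs=1, spatial_dims=1, in_channels=1),
    loss_config=CrossEntropyLossConfig(),
    optimizer_config=AdamConfig(),
)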
@@ -133,9 +98,6 @@
dataset_test = CapsDataset(
caps_directory=caps_directory,
preprocessing=preprocessing_t1,
sub_ses_tsv=Path("test.tsv"), # test only on data from the first dataset
data=Path("test.tsv"), # test only on data from the first dataset
transforms=transforms_image,
)

predictor = Predictor(model=model, manager=manager)
predictor.predict(dataset_test=dataset_test, split_number=2)
27 changes: 12 additions & 15 deletions clinicadl/API/cross_val.py
@@ -1,11 +1,10 @@
from pathlib import Path

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.predictor.predictor import Predictor
from clinicadl.splitter.new_splitter.dataloader import DataLoaderConfig
from clinicadl.splitter.new_splitter.splitter.kfold import KFold
from clinicadl.trainer.trainer import Trainer
from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.data.datasets import CapsDataset
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.trainer import Trainer

# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING

@@ -19,20 +18,18 @@
config_file=config_file, manager=manager
) # gpu, amp, fsdp, seed

splitter = KFold(dataset=dataset_t1_image)
splitter.make_splits(n_splits=3)
split_dir = Path("")
splitter.write(split_dir)
split_dir = make_split(
dataset_t1_image.df, n_test=0.2, subset_name="validation", output_dir="test"
) # Optional data tsv and output_dir
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

splitter = KFold(fold_dir)

splitter.read(split_dir)

# define the needed parameters for the dataloader
dataloader_config = DataLoaderConfig(num_workers=3, batch_size=10)


for split in splitter.get_splits(splits=(0, 3, 4)):
print(split)
for split in splitter.get_splits(dataset=dataset_t1_image):
split.build_train_loader(dataloader_config)
split.build_val_loader(num_workers=3, batch_size=10)

print(split)
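A compact sketch of the whole cross-validation setup in this file, assuming the signatures shown above (the values are illustrative). Both loader-building styles appear in this commit: a DataLoaderConfig object or plain keyword arguments. Note also that this file spells the worker count num_workers while complicated_case.py uses n_procs, so one of the two spellings is presumably an alias or a leftover.

from clinicadl.data.dataloader import DataLoaderConfig
from clinicadl.splitter import KFold, make_kfold, make_split

# single train/validation split, then k folds over the training subjects
split_dir = make_split(
    dataset_t1_image.df, n_test=0.2, subset_name="validation", output_dir="test"
)
fold_dir = make_kfold(split_dir / "train.tsv", n_splits=2)

for split in KFold(fold_dir).get_splits(dataset=dataset_t1_image):
    train_loader = split.build_train_loader(DataLoaderConfig(num_workers=3, batch_size=10))
    val_loader = split.build_val_loader(num_workers=3, batch_size=10)  # kwargs form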
61 changes: 26 additions & 35 deletions clinicadl/API/dataset_test.py
@@ -2,26 +2,25 @@

import torchio.transforms as transforms

from clinicadl.dataset.datasets.caps_dataset import CapsDataset
from clinicadl.dataset.datasets.concat import ConcatDataset
from clinicadl.dataset.preprocessing import (
from clinicadl.data.datasets import CapsDataset, ConcatDataset
from clinicadl.data.preprocessing import (
BasePreprocessing,
PreprocessingFlair,
PreprocessingPET,
PreprocessingT1,
)
from clinicadl.experiment_manager.experiment_manager import ExperimentManager
from clinicadl.data.preprocessing.pet import SUVRReferenceRegions, Tracer
from clinicadl.experiment_manager import ExperimentManager
from clinicadl.losses.config import CrossEntropyLossConfig
from clinicadl.model.clinicadl_model import ClinicaDLModel
from clinicadl.networks.factory import (
ConvEncoderOptions,
create_network_config,
get_network_from_config,
)
from clinicadl.splitter.kfold import KFolder
from clinicadl.splitter.split import get_single_split, split_tsv
from clinicadl.transforms.extraction import ROI, Image, Patch, Slice
from clinicadl.transforms.transforms import Transforms
from clinicadl.splitter import KFold, make_kfold, make_split
from clinicadl.transforms import Transforms
from clinicadl.transforms.extraction import Image, Patch, Slice

sub_ses_t1 = Path("/Users/camille.brianceau/aramis/CLINICADL/caps/subjects_t1.tsv")
sub_ses_pet_45 = Path(
@@ -38,8 +37,12 @@
"/Users/camille.brianceau/aramis/CLINICADL/caps"
) # output of clinica pipelines

preprocessing_pet_45 = PreprocessingPET(tracer="18FAV45", suvr_reference_region="pons2")
preprocessing_pet_11 = PreprocessingPET(tracer="11CPIB", suvr_reference_region="pons2")
preprocessing_pet_45 = PreprocessingPET(
tracer=Tracer.FAV45, suvr_reference_region=SUVRReferenceRegions.PONS2
)
preprocessing_pet_11 = PreprocessingPET(
tracer=Tracer.CPIB, suvr_reference_region=SUVRReferenceRegions.PONS2
)

preprocessing_t1 = PreprocessingT1()
preprocessing_flair = PreprocessingFlair()
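The PET preprocessing constructor now takes enum members instead of the raw strings used before (presumably why the string form needed a # type: ignore elsewhere in this commit). A minimal sketch, assuming the enums imported above:

from clinicadl.data.preprocessing import PreprocessingPET
from clinicadl.data.preprocessing.pet import SUVRReferenceRegions, Tracer

# Tracer.FAV45 / SUVRReferenceRegions.PONS2 replace the raw strings
# "18FAV45" and "pons2" accepted by the old API
preprocessing_pet = PreprocessingPET(
    tracer=Tracer.FAV45,
    suvr_reference_region=SUVRReferenceRegions.PONS2,
)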
@@ -55,18 +58,6 @@

transforms_slice = Transforms(extraction=Slice())

transforms_roi = Transforms(
object_augmentation=[transforms.Ghosting(2, 1, 0.1, 0.1)],
object_transforms=[transforms.RandomMotion()],
extraction=ROI(
roi_list=["leftHippocampusBox", "rightHippocampusBox"],
roi_mask_location=Path(
"/Users/camille.brianceau/aramis/CLINICADL/caps/masks/tpl-MNI152NLin2009cSym"
),
roi_crop_input=True,
),
)

transforms_image = Transforms(
image_augmentation=[transforms.RandomMotion()],
extraction=Image(),
@@ -96,25 +87,25 @@
)


print("Pet 11 and ROI ")
print("Pet 11 and Image ")

dataset_pet_11_roi = CapsDataset(
dataset_pet_11_image = CapsDataset(
caps_directory=caps_directory,
data=sub_ses_pet_11,
preprocessing=preprocessing_pet_11,
transforms=transforms_roi,
transforms=transforms_image,
)
dataset_pet_11_roi.prepare_data(
dataset_pet_11_image.prepare_data(
n_proc=2
) # to extract the tensor of the PET file this time

print(dataset_pet_11_roi)
print(dataset_pet_11_roi.__len__())
print(dataset_pet_11_roi._get_meta_data(0))
print(dataset_pet_11_roi._get_meta_data(1))
# print(dataset_pet_11_roi._get_full_image())
print(dataset_pet_11_roi.__getitem__(1).elem_idx)
print(dataset_pet_11_roi.elem_per_image)
print(dataset_pet_11_image)
print(dataset_pet_11_image.__len__())
print(dataset_pet_11_image._get_meta_data(0))
print(dataset_pet_11_image._get_meta_data(1))
# print(dataset_pet_11_image._get_full_image())
print(dataset_pet_11_image.__getitem__(1).elem_idx)
print(dataset_pet_11_image.elem_per_image)


print("T1 and image ")
@@ -161,7 +152,7 @@

dataset_multi_modality_multi_extract = ConcatDataset(
[
dataset_t1,
dataset_pet,
dataset_t1_image,
dataset_pet_11_image,
]
) # 3 train.tsv files as input that must be concatenated; the same goes for the transforms, so be careful
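ConcatDataset appears to simply concatenate per-modality CapsDatasets, each keeping its own preprocessing and transforms. A short usage sketch under that assumption (the print calls assume the dataset implements the usual __repr__ and __len__):

from clinicadl.data.datasets import ConcatDataset

# each member dataset keeps its own preprocessing/transforms;
# only the samples are concatenated
dataset_multi = ConcatDataset([dataset_t1_image, dataset_pet_11_image])
print(dataset_multi)       # assumed summary of the merged dataset
print(len(dataset_multi))  # assumed total number of samples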
