Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clinicadl-pythae #459

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clinicadl/resources/config/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
[Model]
architecture = "default" # ex : Conv5_FC3
multi_network = false
pythae = false

[Architecture]
# CNN
Expand All @@ -14,6 +15,11 @@ latent_space_size = 128
feature_size = 1024
n_conv = 4
io_layer_channels = 8
first_layer_channels = 32
n_conv_encoder = 4
n_conv_decoder = 4
last_layer_channels = 32
last_layer_conv = false
recons_weight = 1
kl_weight = 1
normalization = "batch"
Expand Down
5 changes: 4 additions & 1 deletion clinicadl/utils/caps_dataset/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def __init__(
f"Columns should include {mandatory_col}"
)
self.elem_per_image = self.num_elem_per_image()
self.size = self[0]["image"].size()
if "image" in self[0].keys():
self.size = self[0]["image"].size()
else:
self.size = self[0].data.size()

@property
@abc.abstractmethod
Expand Down
73 changes: 68 additions & 5 deletions clinicadl/utils/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,33 @@
# TODO save weights on CPU for better compatibility


pythae_models_list = [
"pythae_Adversarial_AE",
"pythae_AE",
"pythae_BetaTCVAE",
"pythae_BetaVAE",
"pythae_CIWAE",
"pythae_DisentangledBetaVAEpythae_FactorVAE",
"pythae_HVAE",
"pythae_IWAE",
"pythae_INFOVAE_MMD",
"pythae_MSSSIM_VAE",
"pythae_MIWAE",
"pythae_PIWAE",
"pythae_PoincareVAE",
"pythae_RAE_GP",
"pythae_RAE_L2",
"pythae_RHVAE",
"pythae_SVAE",
"pythae_VAE",
"pythae_VAE_IAF",
"pythae_VAE_LinNF",
"pythae_VAEGAN",
"pythae_VAMP",
"pythae_VQVAE",
"pythae_WAE_MMD",
]

class MapsManager:
def __init__(
self,
Expand Down Expand Up @@ -1127,11 +1154,23 @@ def _compute_output_nifti(
nb_imgs = len(dataset)
for i in range(nb_imgs):
data = dataset[i]
image = data["image"]

try:
image = data["image"]
except:
image = data["data"]
output = (
model.predict(image.unsqueeze(0).to(model.device))
.squeeze(0)
.detach()
.cpu()
)

x = image.unsqueeze(0).to(model.device)
with autocast(enabled=self.amp):
output = model.predict(x)
output = output.squeeze(0).detach().cpu().float()

# Convert tensor to nifti image with appropriate affine
input_nii = nib.Nifti1Image(image[0].detach().cpu().numpy(), eye(4))
output_nii = nib.Nifti1Image(output[0].numpy(), eye(4))
Expand Down Expand Up @@ -1192,11 +1231,19 @@ def _compute_output_tensors(

for i in range(nb_modes):
data = dataset[i]
image = data["image"]
try:
image = data["image"]
except:
image = data["data"]
output = (
model.predict(image.unsqueeze(0).to(model.device)).squeeze(0).cpu()
)

x = image.unsqueeze(0).to(model.device)
with autocast(enabled=self.amp):
output = model.predict(x)
output = output.squeeze(0).cpu().float()

participant_id = data["participant_id"]
session_id = data["session_id"]
mode_id = data[f"{self.mode}_id"]
Expand Down Expand Up @@ -1709,7 +1756,10 @@ def _write_information(self):
"""
from datetime import datetime

import clinicadl.utils.network as network_package
if self.pythae :
import clinithae.network as network_package
else :
import clinicadl.utils.network as network_package

model_class = getattr(network_package, self.architecture)
args = list(
Expand All @@ -1719,6 +1769,9 @@ def _write_information(self):
)
args.remove("self")
kwargs = dict()
print(args)
print("para")
print(self.parameters)
for arg in args:
kwargs[arg] = self.parameters[arg]
kwargs["gpu"] = False
Expand Down Expand Up @@ -1938,7 +1991,10 @@ def _init_model(
gpu (bool): If given, a new value for the device of the model will be computed.
network (int): Index of the network trained (used in multi-network setting only).
"""
import clinicadl.utils.network as network_package
if self.pythae :
import clinithae.network as network_package
else :
import clinicadl.utils.network as network_package

logger.debug(f"Initialization of model {self.architecture}")
# or choose to implement a dictionary
Expand Down Expand Up @@ -1984,7 +2040,14 @@ def _init_model(
)
transfer_class = getattr(network_package, transfer_maps.architecture)
logger.debug(f"Transfer from {transfer_class}")
model.transfer_weights(transfer_state["model"], transfer_class)
if "model" in transfer_state.keys():
model.transfer_weights(transfer_state["model"], transfer_class)
elif "model_state_dict" in transfer_state.keys():
model.transfer_weights(
transfer_state["model_state_dict"], transfer_class
)
else:
raise KeyError("Unknow key in model state dictionnary.")

return model, current_epoch

Expand Down
3 changes: 1 addition & 2 deletions clinicadl/utils/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def write_preprocessing(preprocessing_dict: Dict[str, Any], caps_directory: Path

def read_preprocessing(json_path: Path) -> Dict[str, Any]:
if not json_path.name.endswith(".json"):
json_path += ".json"
json_path = json_path
json_path = Path(json_path) / ".json"

if not json_path.is_file():
raise FileNotFoundError(errno.ENOENT, json_path)
Expand Down
Loading