diff --git a/clinicadl/interpret/gradients.py b/clinicadl/interpret/gradients.py
index 393e64488..b62308f38 100644
--- a/clinicadl/interpret/gradients.py
+++ b/clinicadl/interpret/gradients.py
@@ -1,7 +1,7 @@
 import abc
 
 import torch
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from clinicadl.utils.exceptions import ClinicaDLArgumentError
 
@@ -28,7 +28,7 @@ def generate_gradients(
         # Forward
         input_batch = input_batch.to(self.device)
         input_batch.requires_grad = True
-        with autocast(enabled=amp):
+        with autocast("cuda", enabled=amp):
             if hasattr(self.model, "variational") and self.model.variational:
                 _, _, _, model_output = self.model(input_batch)
             else:
@@ -94,7 +94,7 @@ def generate_gradients(
         # Get last conv feature map
         feature_maps = conv_part(input_batch).detach()
         feature_maps.requires_grad = True
-        with autocast(enabled=amp):
+        with autocast("cuda", enabled=amp):
             model_output = fc_part(pre_fc_part(feature_maps))
         # Target for backprop
         one_hot_output = torch.zeros_like(model_output)
diff --git a/clinicadl/maps_manager/maps_manager.py b/clinicadl/maps_manager/maps_manager.py
index 35f06411c..db77681ff 100644
--- a/clinicadl/maps_manager/maps_manager.py
+++ b/clinicadl/maps_manager/maps_manager.py
@@ -8,7 +8,7 @@
 import pandas as pd
 import torch
 import torch.distributed as dist
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from clinicadl.caps_dataset.caps_dataset_utils import read_json
 from clinicadl.caps_dataset.data import (
@@ -350,7 +350,7 @@ def _compute_output_tensors(
             data = dataset[i]
             image = data["image"]
             x = image.unsqueeze(0).to(model.device)
-            with autocast(enabled=self.std_amp):
+            with autocast("cuda", enabled=self.std_amp):
                 output = model(x)
             output = output.squeeze(0).cpu().float()
             participant_id = data["participant_id"]
diff --git a/clinicadl/predict/predict_manager.py b/clinicadl/predict/predict_manager.py
index 3d103ad8a..af0720c4c 100644
--- a/clinicadl/predict/predict_manager.py
+++ b/clinicadl/predict/predict_manager.py
@@ -7,7 +7,7 @@
 import pandas as pd
 import torch
 import torch.distributed as dist
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
@@ -509,7 +509,7 @@ def _compute_latent_tensors(
             data = dataset[i]
             image = data["image"]
             logger.debug(f"Image for latent representation {image}")
-            with autocast(enabled=self.maps_manager.std_amp):
+            with autocast("cuda", enabled=self.maps_manager.std_amp):
                 _, latent, _ = model.module._forward(
                     image.unsqueeze(0).to(model.device)
                 )
@@ -583,7 +583,7 @@ def _compute_output_nifti(
             data = dataset[i]
             image = data["image"]
             x = image.unsqueeze(0).to(model.device)
-            with autocast(enabled=self.maps_manager.std_amp):
+            with autocast("cuda", enabled=self.maps_manager.std_amp):
                 output = model(x)
             output = output.squeeze(0).detach().cpu().float()
             # Convert tensor to nifti image with appropriate affine
diff --git a/clinicadl/quality_check/t1_linear/quality_check.py b/clinicadl/quality_check/t1_linear/quality_check.py
index f840a4583..b916d3186 100755
--- a/clinicadl/quality_check/t1_linear/quality_check.py
+++ b/clinicadl/quality_check/t1_linear/quality_check.py
@@ -8,7 +8,7 @@
 
 import pandas as pd
 import torch
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.utils.data import DataLoader
 
 from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
@@ -153,7 +153,7 @@ def quality_check(
             inputs = data["image"]
             if computational_config.gpu:
                 inputs = inputs.cuda()
-            with autocast(enabled=computational_config.amp):
+            with autocast("cuda", enabled=computational_config.amp):
                 outputs = softmax(model(inputs))
             # We cast back to 32bits. It should be a no-op as softmax is not eligible
             # to fp16 and autocast is forbidden on CPU (output would be bf16 otherwise).
diff --git a/clinicadl/trainer/tasks_utils.py b/clinicadl/trainer/tasks_utils.py
index 4b5a012ef..be6d6df11 100644
--- a/clinicadl/trainer/tasks_utils.py
+++ b/clinicadl/trainer/tasks_utils.py
@@ -12,7 +12,7 @@
     model_validator,
 )
 from torch import Tensor, nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.nn.functional import softmax
 from torch.nn.modules.loss import _Loss
 from torch.utils.data import DataLoader, Sampler, sampler
@@ -240,7 +240,7 @@ def test(
     with torch.no_grad():
         for i, data in enumerate(dataloader):
            # initialize the loss list to save the loss components
-            with autocast(enabled=amp):
+            with autocast("cuda", enabled=amp):
                 outputs, loss_dict = model(data, criterion, use_labels=use_labels)
 
             if i == 0:
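
Note (not part of the patch): the change replaces the deprecated `torch.cuda.amp.autocast` with the device-agnostic `torch.amp.autocast`, which takes the device type as an explicit first argument; `enabled=False` keeps the context a no-op, which is how the `amp` / `std_amp` / `computational_config.amp` flags above disable mixed precision on CPU. A minimal standalone sketch of the migrated call pattern (the model and tensor below are illustrative stand-ins, not ClinicaDL code):

# Minimal sketch of the new API; assumes a PyTorch version that ships torch.amp (1.10+).
import torch
from torch.amp import autocast

model = torch.nn.Linear(16, 2)   # illustrative stand-in for a ClinicaDL model
x = torch.randn(1, 16)

use_amp = torch.cuda.is_available()  # mirrors how the amp flags gate mixed precision
if use_amp:
    model, x = model.cuda(), x.cuda()

# Device type is now a positional argument; enabled=False makes the context a no-op.
with autocast("cuda", enabled=use_amp):
    output = model(x)

output = output.float()  # cast back to fp32 after leaving the autocast region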