Skip to content

Commit

Permalink
deprecated GradScaler and autocast
Browse files Browse the repository at this point in the history
  • Loading branch information
camillebrianceau committed Sep 4, 2024
1 parent b04270d commit 8a26ed0
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions clinicadl/interpret/gradients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc

import torch
from torch.cuda.amp import autocast
from torch.amp import autocast

from clinicadl.utils.exceptions import ClinicaDLArgumentError

Expand All @@ -28,7 +28,7 @@ def generate_gradients(
# Forward
input_batch = input_batch.to(self.device)
input_batch.requires_grad = True
with autocast(enabled=amp):
with autocast("cuda", enabled=amp):
if hasattr(self.model, "variational") and self.model.variational:
_, _, _, model_output = self.model(input_batch)
else:
Expand Down Expand Up @@ -94,7 +94,7 @@ def generate_gradients(
# Get last conv feature map
feature_maps = conv_part(input_batch).detach()
feature_maps.requires_grad = True
with autocast(enabled=amp):
with autocast("cuda", enabled=amp):
model_output = fc_part(pre_fc_part(feature_maps))
# Target for backprop
one_hot_output = torch.zeros_like(model_output)
Expand Down
4 changes: 2 additions & 2 deletions clinicadl/maps_manager/maps_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.amp import autocast

from clinicadl.caps_dataset.caps_dataset_utils import read_json
from clinicadl.caps_dataset.data import (
Expand Down Expand Up @@ -350,7 +350,7 @@ def _compute_output_tensors(
data = dataset[i]
image = data["image"]
x = image.unsqueeze(0).to(model.device)
with autocast(enabled=self.std_amp):
with autocast("cuda", enabled=self.std_amp):
output = model(x)
output = output.squeeze(0).cpu().float()
participant_id = data["participant_id"]
Expand Down
6 changes: 3 additions & 3 deletions clinicadl/predict/predict_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

Expand Down Expand Up @@ -509,7 +509,7 @@ def _compute_latent_tensors(
data = dataset[i]
image = data["image"]
logger.debug(f"Image for latent representation {image}")
with autocast(enabled=self.maps_manager.std_amp):
with autocast("cuda", enabled=self.maps_manager.std_amp):
_, latent, _ = model.module._forward(
image.unsqueeze(0).to(model.device)
)
Expand Down Expand Up @@ -583,7 +583,7 @@ def _compute_output_nifti(
data = dataset[i]
image = data["image"]
x = image.unsqueeze(0).to(model.device)
with autocast(enabled=self.maps_manager.std_amp):
with autocast("cuda", enabled=self.maps_manager.std_amp):
output = model(x)
output = output.squeeze(0).detach().cpu().float()
# Convert tensor to nifti image with appropriate affine
Expand Down
4 changes: 2 additions & 2 deletions clinicadl/quality_check/t1_linear/quality_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pandas as pd
import torch
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.utils.data import DataLoader

from clinicadl.caps_dataset.caps_dataset_config import CapsDatasetConfig
Expand Down Expand Up @@ -153,7 +153,7 @@ def quality_check(
inputs = data["image"]
if computational_config.gpu:
inputs = inputs.cuda()
with autocast(enabled=computational_config.amp):
with autocast("cuda", enabled=computational_config.amp):
outputs = softmax(model(inputs))
# We cast back to 32bits. It should be a no-op as softmax is not eligible
# to fp16 and autocast is forbidden on CPU (output would be bf16 otherwise).
Expand Down
4 changes: 2 additions & 2 deletions clinicadl/trainer/tasks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
model_validator,
)
from torch import Tensor, nn
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.nn.functional import softmax
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader, Sampler, sampler
Expand Down Expand Up @@ -240,7 +240,7 @@ def test(
with torch.no_grad():
for i, data in enumerate(dataloader):
# initialize the loss list to save the loss components
with autocast(enabled=amp):
with autocast("cuda", enabled=amp):
outputs, loss_dict = model(data, criterion, use_labels=use_labels)

if i == 0:
Expand Down

0 comments on commit 8a26ed0

Please sign in to comment.