Skip to content

Commit

Permalink
fix for bug Project-MONAI#17 wrong input device
Browse files Browse the repository at this point in the history
Signed-off-by: elitap <[email protected]>
  • Loading branch information
elitap committed Aug 30, 2023
1 parent a79e81d commit 75c3193
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ def update_slice(
continue

inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1)
if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"):
if device and ((isinstance(device, str) and device.startswith('cuda')) or isinstance(device, torch.device) and device.type == "cuda"):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(
inputs, class_prompts, point_prompts, start_idx, original_affine, device=device
)

predictor.eval()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"):
with torch.cuda.amp.autocast():
outputs = predictor(data)
logit = outputs[0]["high_res_logits"]
Expand Down Expand Up @@ -297,14 +297,14 @@ def iterate_all(
)
for start_idx in start_range:
inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1)
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device)
predictor = predictor.eval()
with autocast():
if cachedEmbedding:
curr_embedding = cachedEmbedding[start_idx]
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"):
curr_embedding = curr_embedding.cuda()
outputs = predictor.get_mask_prediction(data, curr_embedding)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi

class_list = [[i + 1] for i in class_prompts]
unique_labels = torch.tensor(class_list).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"):
unique_labels = unique_labels.cuda()

volume_point_coords = [cp for cp in foreground_all]
Expand Down Expand Up @@ -133,7 +133,7 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi
if point_coords:
point_coords = torch.tensor(point_coords).long()
point_labels = torch.tensor(point_labels).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
if (isinstance(device, str) and device.startswith('cuda')) or (isinstance(device, torch.device) and device.type == "cuda"):
point_coords = point_coords.cuda()
point_labels = point_labels.cuda()

Expand Down

0 comments on commit 75c3193

Please sign in to comment.