From b8fcaa59d29fde67302f7f50041d053375264ce6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 30 Oct 2023 13:28:19 +0100 Subject: [PATCH] Fix instance segmentation on GPU and CPU --- micro_sam/instance_segmentation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 9df529e6..c05bbaa8 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -70,13 +70,16 @@ def mask_data_to_segmentation( masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) segmentation = np.zeros(shape[:2], dtype="uint32") + def require_numpy(mask): + return mask.cpu().numpy() if torch.is_tensor(mask) else mask + seg_id = 1 for mask in masks: if mask["area"] < min_object_size: continue this_seg_id = mask.get("seg_id", seg_id) - segmentation[mask["segmentation"].cpu()] = this_seg_id + segmentation[require_numpy(mask["segmentation"])] = this_seg_id seg_id = this_seg_id + 1