From 7b59778afc1141c2a09900748c19e2b6d3f7e995 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sun, 29 Oct 2023 13:05:02 +0100
Subject: [PATCH] Add support for batched inference with prompts

---
 micro_sam/evaluation/inference.py  |  95 ++++-------
 micro_sam/inference.py             | 146 +++++++++++++++++++++++++++++
 micro_sam/instance_segmentation.py |   7 +-
 3 files changed, 179 insertions(+), 69 deletions(-)
 create mode 100644 micro_sam/inference.py

diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py
index 5d8943ce..93ae61ad 100644
--- a/micro_sam/evaluation/inference.py
+++ b/micro_sam/evaluation/inference.py
@@ -16,9 +16,9 @@
 from tqdm import tqdm
 
 from segment_anything import SamPredictor
-from segment_anything.utils.transforms import ResizeLongestSide
 
 from .. import util as util
+from ..inference import batched_inference
 from ..instance_segmentation import mask_data_to_segmentation
 from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 from ..training import get_trainable_sam_model, ConvertToSamInputs
@@ -72,7 +72,6 @@ def _get_batched_prompts(
     n_positives,
     n_negatives,
     dilation,
-    transform_function,
 ):
     # Initialize the prompt generator.
     prompt_generator = PointAndBoxPromptGenerator(
@@ -87,25 +86,21 @@
     bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
 
     masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
-    input_points, input_labels, input_boxes, _ = prompt_generator(
+    points, point_labels, boxes, _ = prompt_generator(
         masks, bbox_coordinates, center_coordinates
     )
 
-    # apply the transforms to the points and boxes
-    if use_boxes:
-        input_boxes = torch.from_numpy(
-            transform_function.apply_boxes(input_boxes.numpy(), gt.shape)
-        )
-    if use_points:
-        input_points = torch.from_numpy(
-            transform_function.apply_coords(input_points.numpy(), gt.shape)
-        )
+    def to_numpy(x):
+        if x is None:
+            return x
+        return x.numpy()
 
-    return input_points, input_labels, input_boxes
+    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)
 
 
 def _run_inference_with_prompts_for_image(
     predictor,
+    image,
     gt,
     use_points,
     use_boxes,
@@ -114,62 +109,30 @@
     dilation,
     batch_size,
     cached_prompts,
+    embedding_path,
 ):
-    # We need the resize transformation for the expected model input size.
-    transform_function = ResizeLongestSide(1024)
     gt_ids = np.unique(gt)[1:]
 
     if cached_prompts is None:
-        input_points, input_labels, input_boxes = _get_batched_prompts(
-            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function,
+        points, point_labels, boxes = _get_batched_prompts(
+            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
         )
     else:
-        input_points, input_labels, input_boxes = cached_prompts
+        points, point_labels, boxes = cached_prompts
 
     # Make a copy of the point prompts to return them at the end.
-    prompts = deepcopy((input_points, input_labels, input_boxes))
-
-    # Transform the prompts into batches
-    device = predictor.device
-    input_points = None if input_points is None else torch.tensor(np.array(input_points), dtype=torch.float32).to(device)
-    input_labels = None if input_labels is None else torch.tensor(np.array(input_labels), dtype=torch.float32).to(device)
-    input_boxes = None if input_boxes is None else torch.tensor(np.array(input_boxes), dtype=torch.float32).to(device)
+    prompts = deepcopy((points, point_labels, boxes))
 
     # Use multi-masking only if we have a single positive point without box
     multimasking = False
     if not use_boxes and (n_positives == 1 and n_negatives == 0):
         multimasking = True
 
-    # Run the batched inference.
-    n_samples = input_boxes.shape[0] if input_points is None else input_points.shape[0]
-    n_batches = int(np.ceil(float(n_samples) / batch_size))
-    masks, ious = [], []
-    with torch.no_grad():
-        for batch_idx in range(n_batches):
-            batch_start = batch_idx * batch_size
-            batch_stop = min((batch_idx + 1) * batch_size, n_samples)
-
-            batch_points = None if input_points is None else input_points[batch_start:batch_stop]
-            batch_labels = None if input_labels is None else input_labels[batch_start:batch_stop]
-            batch_boxes = None if input_boxes is None else input_boxes[batch_start:batch_stop]
-
-            batch_masks, batch_ious, _ = predictor.predict_torch(
-                point_coords=batch_points, point_labels=batch_labels,
-                boxes=batch_boxes, multimask_output=multimasking
-            )
-            masks.append(batch_masks)
-            ious.append(batch_ious)
-    masks = torch.cat(masks)
-    ious = torch.cat(ious)
-    assert len(masks) == len(ious) == n_samples
-
-    # TODO we should actually use non-max suppression here
-    # I will implement it somewhere to have it refactored
-    instance_labels = np.zeros_like(gt, dtype=int)
-    for m, iou, gt_idx in zip(masks, ious, gt_ids):
-        best_idx = torch.argmax(iou)
-        best_mask = m[best_idx]
-        instance_labels[best_mask.detach().cpu().numpy()] = gt_idx
+    instance_labels = batched_inference(
+        predictor, image, batch_size,
+        boxes=boxes, points=points, point_labels=point_labels,
+        multimasking=multimasking, embedding_path=embedding_path,
+        return_instance_segmentation=True,
+    )
 
     return instance_labels, prompts
 
@@ -203,7 +166,9 @@ def get_predictor(
         )
     else:  # Vanilla SAM model
         assert not return_state
-        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)  # type: ignore
+        predictor = util.get_sam_model(
+            model_type=model_type, device=device, checkpoint_path=checkpoint_path
+        )  # type: ignore
     return predictor
 
 
@@ -228,7 +193,7 @@ def precompute_all_embeddings(
         util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
 
 
-def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function):
+def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
     name = os.path.basename(gt_path)
 
     gt = imageio.imread(gt_path).astype("uint32")
@@ -236,7 +201,7 @@ def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives
     gt_ids = np.unique(gt)[1:]
     input_point, input_label, input_box = _get_batched_prompts(
-        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function
+        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
     )
 
     if use_boxes and not use_points:
@@ -259,7 +224,6 @@ def precompute_all_prompts(
         prompt_settings: The settings for which the prompts will be computed.
""" os.makedirs(prompt_save_dir, exist_ok=True) - transform_function = ResizeLongestSide(1024) for settings in tqdm(prompt_settings, desc="Precompute prompts"): @@ -284,7 +248,6 @@ def precompute_all_prompts( n_positives=n_positives, n_negatives=n_negatives, dilation=dilation, - transform_function=transform_function, ) results.append(prompts) @@ -353,7 +316,7 @@ def run_inference_with_prompts( gt_paths: The ground-truth segmentation file paths. embedding_dir: The directory where the image embddings will be saved or are already saved. use_points: Whether to use point prompts. - use_boxes: Whetehr to use box prompts + use_boxes: Whether to use box prompts n_positives: The number of positive point prompts that will be sampled. n_negativess: The number of negative point prompts that will be sampled. dilation: The dilation factor for the radius around the ground-truth object @@ -393,18 +356,16 @@ def run_inference_with_prompts( gt = relabel_sequential(gt)[0] embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") - image_embeddings = util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2) - util.set_precomputed(predictor, image_embeddings) - this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts( cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, label_name ) instances, this_prompts = _run_inference_with_prompts_for_image( - predictor, gt, n_positives=n_positives, n_negatives=n_negatives, + predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives, dilation=dilation, use_points=use_points, use_boxes=use_boxes, - batch_size=batch_size, cached_prompts=this_prompts + batch_size=batch_size, cached_prompts=this_prompts, + embedding_path=embedding_path, ) if save_point_prompts: diff --git a/micro_sam/inference.py b/micro_sam/inference.py new file mode 100644 index 00000000..b4b06dca --- /dev/null +++ b/micro_sam/inference.py @@ -0,0 +1,146 @@ +import os +from typing import Optional, Union + +import torch +import numpy as np + +import segment_anything.utils.amg as amg_utils +from segment_anything import SamPredictor +from segment_anything.utils.transforms import ResizeLongestSide + +from . import util +from .instance_segmentation import mask_data_to_segmentation +from ._vendored import batched_mask_to_box + + +@torch.no_grad() +def batched_inference( + predictor: SamPredictor, + image: np.ndarray, + batch_size: int, + boxes: Optional[np.ndarray] = None, + points: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + multimasking: bool = False, + embedding_path: Optional[Union[str, os.PathLike]] = None, + return_instance_segmentation: bool = True, + segmentation_ids: Optional[list] = None, +): + """Run batched inference for input prompts. + + Args: + predictor: The segment anything predictor. + image: The input image. + batch_size: The batch size to use for inference. + boxes: The box prompts. Array of shape N_PROMPTS x 4. + The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y]. + points: The point prompt coordinates. Array of shape N_PROMPTS x 2. + The points are represented by [X, Y]. + point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. + The labels are either 0 (negative prompt) or 1 (positive prompt). + multimasking: Whether to predict with 3 or 1 mask. + embedding_path: Cache path for the image embeddings. + return_instance_segmentation: Whether to return a instance segmentation + or the individual mask data. 
+        segmentation_ids: Fixed segmentation ids to assign to the masks
+            derived from the prompts.
+
+    Returns:
+        The predicted segmentation masks.
+    """
+    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
+        raise NotImplementedError
+
+    if (points is None) != (point_labels is None):
+        raise ValueError(
+            "If you have point prompts both `points` and `point_labels` have to be passed, "
+            "but you passed only one of them."
+        )
+
+    have_points = points is not None
+    have_boxes = boxes is not None
+    if (not have_points) and (not have_boxes):
+        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")
+
+    if have_points and (len(point_labels) != len(points)):
+        raise ValueError(
+            "The number of point coordinates and labels does not match: "
+            f"{len(point_labels)} != {len(points)}"
+        )
+
+    if (have_points and have_boxes) and (len(points) != len(boxes)):
+        raise ValueError(
+            "The number of point and box prompts does not match: "
+            f"{len(points)} != {len(boxes)}"
+        )
+    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]
+
+    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
+        raise ValueError(
+            "The number of segmentation ids and prompts does not match: "
+            f"{len(segmentation_ids)} != {n_prompts}"
+        )
+
+    # Compute the image embeddings.
+    image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
+    util.set_precomputed(predictor, image_embeddings)
+
+    # Determine the number of batches.
+    n_batches = int(np.ceil(float(n_prompts) / batch_size))
+
+    # Preprocess the prompts.
+    device = predictor.device
+    transform_function = ResizeLongestSide(1024)
+    image_shape = predictor.original_size
+    if have_boxes:
+        boxes = transform_function.apply_boxes(boxes, image_shape)
+        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
+    if have_points:
+        points = transform_function.apply_coords(points, image_shape)
+        points = torch.tensor(points, dtype=torch.float32).to(device)
+        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)
+
+    masks = amg_utils.MaskData()
+    for batch_idx in range(n_batches):
+        batch_start = batch_idx * batch_size
+        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)
+
+        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
+        batch_points = points[batch_start:batch_stop] if have_points else None
+        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
+
+        batch_masks, batch_ious, _ = predictor.predict_torch(
+            point_coords=batch_points, point_labels=batch_labels,
+            boxes=batch_boxes, multimask_output=multimasking
+        )
+
+        # If we return the merged instance segmentation and use multi-masking,
+        # then we need to select the most likely mask (according to the predicted IOU) here.
+        if return_instance_segmentation and multimasking:
+            _, max_index = batch_ious.max(axis=1)
+            # How can this be vectorized???
+            batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
+            batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
+
+        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
+        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
+        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
+
+        masks.cat(batch_data)
+
+    # Mask data to records.
+    masks = [
+        {
+            "segmentation": masks["masks"][idx],
+            "area": masks["masks"][idx].sum(),
+            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
+            "predicted_iou": masks["iou_preds"][idx].item(),
+            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
+        }
+        for idx in range(len(masks["masks"]))
+    ]
+
+    if return_instance_segmentation:
+        masks = mask_data_to_segmentation(masks, image_shape, with_background=False, min_object_size=0)
+
+    return masks
diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index 14cdc0aa..0173d811 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -74,8 +74,11 @@ def mask_data_to_segmentation(
     for mask in masks:
         if mask["area"] < min_object_size:
             continue
-        segmentation[mask["segmentation"]] = seg_id
-        seg_id += 1
+
+        this_seg_id = mask.get("seg_id", seg_id)
+        segmentation[mask["segmentation"]] = this_seg_id
+
+        seg_id = this_seg_id + 1
 
     if with_background:
         seg_ids, sizes = np.unique(segmentation, return_counts=True)
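
Usage note: below is a minimal sketch of how the new `batched_inference` function can be called with box prompts. The image, box values, batch size, model type, and embedding cache path are made up for illustration; only the `batched_inference` signature and `util.get_sam_model` follow the code in this patch.

```python
import numpy as np

from micro_sam import util
from micro_sam.inference import batched_inference

# Hypothetical example data: a 2D image and one box prompt per object,
# with boxes given as [MIN_X, MIN_Y, MAX_X, MAX_Y] (see the docstring above).
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
boxes = np.array([
    [10, 20, 60, 80],
    [100, 120, 180, 200],
])

predictor = util.get_sam_model(model_type="vit_b")  # model type chosen for illustration
segmentation = batched_inference(
    predictor, image, batch_size=16,
    boxes=boxes,
    # Optionally cache the image embeddings, e.g. embedding_path="embeddings/example.zarr".
    return_instance_segmentation=True,
)
# With return_instance_segmentation=True the result is a label image; passing
# segmentation_ids=[2, 5] would assign those ids to the two masks via the new
# "seg_id" handling in mask_data_to_segmentation.
```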