Merge pull request #253 from computational-cell-analytics/inference

Add support for batched inference with prompts

constantinpape authored Oct 29, 2023
2 parents 579156e + 7b59778, commit 98a3e40

Showing 3 changed files with 179 additions and 69 deletions.
95 changes: 28 additions & 67 deletions micro_sam/evaluation/inference.py
@@ -16,9 +16,9 @@
 from tqdm import tqdm

 from segment_anything import SamPredictor
-from segment_anything.utils.transforms import ResizeLongestSide

 from .. import util as util
+from ..inference import batched_inference
 from ..instance_segmentation import mask_data_to_segmentation
 from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 from ..training import get_trainable_sam_model, ConvertToSamInputs
@@ -72,7 +72,6 @@ def _get_batched_prompts(
     n_positives,
     n_negatives,
     dilation,
-    transform_function,
 ):
     # Initialize the prompt generator.
     prompt_generator = PointAndBoxPromptGenerator(
@@ -87,25 +86,21 @@
     bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
     masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

-    input_points, input_labels, input_boxes, _ = prompt_generator(
+    points, point_labels, boxes, _ = prompt_generator(
         masks, bbox_coordinates, center_coordinates
     )

-    # apply the transforms to the points and boxes
-    if use_boxes:
-        input_boxes = torch.from_numpy(
-            transform_function.apply_boxes(input_boxes.numpy(), gt.shape)
-        )
-    if use_points:
-        input_points = torch.from_numpy(
-            transform_function.apply_coords(input_points.numpy(), gt.shape)
-        )
+    def to_numpy(x):
+        if x is None:
+            return x
+        return x.numpy()

-    return input_points, input_labels, input_boxes
+    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)


 def _run_inference_with_prompts_for_image(
     predictor,
+    image,
     gt,
     use_points,
     use_boxes,
@@ -114,62 +109,30 @@ def _run_inference_with_prompts_for_image(
     dilation,
     batch_size,
     cached_prompts,
+    embedding_path,
 ):
-    # We need the resize transformation for the expected model input size.
-    transform_function = ResizeLongestSide(1024)
     gt_ids = np.unique(gt)[1:]

     if cached_prompts is None:
-        input_points, input_labels, input_boxes = _get_batched_prompts(
-            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function,
+        points, point_labels, boxes = _get_batched_prompts(
+            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
         )
     else:
-        input_points, input_labels, input_boxes = cached_prompts
+        points, point_labels, boxes = cached_prompts

     # Make a copy of the point prompts to return them at the end.
-    prompts = deepcopy((input_points, input_labels, input_boxes))
-
-    # Transform the prompts into batches
-    device = predictor.device
-    input_points = None if input_points is None else torch.tensor(np.array(input_points), dtype=torch.float32).to(device)
-    input_labels = None if input_labels is None else torch.tensor(np.array(input_labels), dtype=torch.float32).to(device)
-    input_boxes = None if input_boxes is None else torch.tensor(np.array(input_boxes), dtype=torch.float32).to(device)
+    prompts = deepcopy((points, point_labels, boxes))

     # Use multi-masking only if we have a single positive point without box
     multimasking = False
     if not use_boxes and (n_positives == 1 and n_negatives == 0):
         multimasking = True

-    # Run the batched inference.
-    n_samples = input_boxes.shape[0] if input_points is None else input_points.shape[0]
-    n_batches = int(np.ceil(float(n_samples) / batch_size))
-    masks, ious = [], []
-    with torch.no_grad():
-        for batch_idx in range(n_batches):
-            batch_start = batch_idx * batch_size
-            batch_stop = min((batch_idx + 1) * batch_size, n_samples)
-
-            batch_points = None if input_points is None else input_points[batch_start:batch_stop]
-            batch_labels = None if input_labels is None else input_labels[batch_start:batch_stop]
-            batch_boxes = None if input_boxes is None else input_boxes[batch_start:batch_stop]
-
-            batch_masks, batch_ious, _ = predictor.predict_torch(
-                point_coords=batch_points, point_labels=batch_labels,
-                boxes=batch_boxes, multimask_output=multimasking
-            )
-            masks.append(batch_masks)
-            ious.append(batch_ious)
-    masks = torch.cat(masks)
-    ious = torch.cat(ious)
-    assert len(masks) == len(ious) == n_samples
-
-    # TODO we should actually use non-max suppression here
-    # I will implement it somewhere to have it refactored
-    instance_labels = np.zeros_like(gt, dtype=int)
-    for m, iou, gt_idx in zip(masks, ious, gt_ids):
-        best_idx = torch.argmax(iou)
-        best_mask = m[best_idx]
-        instance_labels[best_mask.detach().cpu().numpy()] = gt_idx
+    instance_labels = batched_inference(
+        predictor, image, batch_size,
+        boxes=boxes, points=points, point_labels=point_labels,
+        multimasking=multimasking, embedding_path=embedding_path,
+        return_instance_segmentation=True,
+    )

     return instance_labels, prompts

@@ -203,7 +166,9 @@ def get_predictor(
         )
     else:  # Vanilla SAM model
         assert not return_state
-        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)  # type: ignore
+        predictor = util.get_sam_model(
+            model_type=model_type, device=device, checkpoint_path=checkpoint_path
+        )  # type: ignore
     return predictor
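
For orientation, a hedged usage sketch of get_predictor; the keyword names follow the call above, and the vanilla-SAM case without a custom checkpoint is an assumption, not confirmed by this diff:

predictor = get_predictor(checkpoint_path=None, model_type="vit_b")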


@@ -228,15 +193,15 @@ def precompute_all_embeddings(
         util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)


-def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function):
+def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
     name = os.path.basename(gt_path)

     gt = imageio.imread(gt_path).astype("uint32")
     gt = relabel_sequential(gt)[0]
     gt_ids = np.unique(gt)[1:]

     input_point, input_label, input_box = _get_batched_prompts(
-        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function
+        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
     )

     if use_boxes and not use_points:
@@ -259,7 +224,6 @@ def precompute_all_prompts(
         prompt_settings: The settings for which the prompts will be computed.
     """
     os.makedirs(prompt_save_dir, exist_ok=True)
-    transform_function = ResizeLongestSide(1024)

     for settings in tqdm(prompt_settings, desc="Precompute prompts"):

@@ -284,7 +248,6 @@
                 n_positives=n_positives,
                 n_negatives=n_negatives,
                 dilation=dilation,
-                transform_function=transform_function,
             )
             results.append(prompts)
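
For orientation, a sketch of what a single entry in prompt_settings might look like; the key names are inferred from the surrounding code and not confirmed by this diff:

settings = {
    "use_points": True,   # sample point prompts
    "use_boxes": False,   # do not sample box prompts
    "n_positives": 1,     # one positive point per object
    "n_negatives": 0,     # no negative points
    "dilation": 5,        # dilation radius around each object for prompt sampling
}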

@@ -353,7 +316,7 @@ def run_inference_with_prompts(
         gt_paths: The ground-truth segmentation file paths.
         embedding_dir: The directory where the image embeddings will be saved or are already saved.
         use_points: Whether to use point prompts.
-        use_boxes: Whetehr to use box prompts
+        use_boxes: Whether to use box prompts.
         n_positives: The number of positive point prompts that will be sampled.
         n_negatives: The number of negative point prompts that will be sampled.
         dilation: The dilation factor for the radius around the ground-truth object
@@ -393,18 +356,16 @@
         gt = relabel_sequential(gt)[0]

         embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
-        image_embeddings = util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
-        util.set_precomputed(predictor, image_embeddings)
-
         this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
             cached_point_prompts, save_point_prompts,
             cached_box_prompts, save_box_prompts,
             label_name
         )
         instances, this_prompts = _run_inference_with_prompts_for_image(
-            predictor, gt, n_positives=n_positives, n_negatives=n_negatives,
+            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
             dilation=dilation, use_points=use_points, use_boxes=use_boxes,
-            batch_size=batch_size, cached_prompts=this_prompts
+            batch_size=batch_size, cached_prompts=this_prompts,
+            embedding_path=embedding_path,
         )

         if save_point_prompts:
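A hedged end-to-end sketch of the updated evaluation entry point; prediction_dir, the checkpoint path, and the path lists are assumptions not shown in this diff:

from micro_sam.evaluation.inference import get_predictor, run_inference_with_prompts

predictor = get_predictor(checkpoint_path="checkpoint.pt", model_type="vit_b")  # hypothetical checkpoint
run_inference_with_prompts(
    predictor,
    image_paths=image_paths, gt_paths=gt_paths,  # lists of image / label file paths
    embedding_dir="embeddings/", prediction_dir="predictions/",
    use_points=True, use_boxes=False,
    n_positives=1, n_negatives=0,
)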
146 changes: 146 additions & 0 deletions micro_sam/inference.py
@@ -0,0 +1,146 @@
import os
from typing import Optional, Union

import torch
import numpy as np

import segment_anything.utils.amg as amg_utils
from segment_anything import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
from .instance_segmentation import mask_data_to_segmentation
from ._vendored import batched_mask_to_box


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
):
"""Run batched inference for input prompts.
Args:
predictor: The segment anything predictor.
image: The input image.
batch_size: The batch size to use for inference.
boxes: The box prompts. Array of shape N_PROMPTS x 4.
The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
points: The point prompt coordinates. Array of shape N_PROMPTS x 2.
The points are represented by [X, Y].
point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
The labels are either 0 (negative prompt) or 1 (positive prompt).
multimasking: Whether to predict with 3 or 1 mask.
embedding_path: Cache path for the image embeddings.
return_instance_segmentation: Whether to return a instance segmentation
or the individual mask data.
segmentation_ids: Fixed segmentation ids to assign to the masks
derived from the prompts.
Returns:
The predicted segmentation masks.
"""
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )
    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Compute the image embeddings.
    image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
    util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    masks = amg_utils.MaskData()
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None

        batch_masks, batch_ious, _ = predictor.predict_torch(
            point_coords=batch_points, point_labels=batch_labels,
            boxes=batch_boxes, multimask_output=multimasking
        )

        # If we return the merged instance segmentation and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if return_instance_segmentation and multimasking:
            _, max_index = batch_ious.max(axis=1)
            # How can this be vectorized???
            batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
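            # Editorial note: a loop-free alternative would be advanced indexing,
            # e.g. batch_masks[torch.arange(len(max_index)), max_index][:, None]
            # (an untested sketch, not part of this commit).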

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])

        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = mask_data_to_segmentation(masks, image_shape, with_background=False, min_object_size=0)

    return masks
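
For reference, a minimal usage sketch of the new function; the image path and box coordinates are made up for illustration:

import imageio.v3 as imageio
import numpy as np

from micro_sam.inference import batched_inference
from micro_sam.util import get_sam_model

image = imageio.imread("example_image.tif")  # hypothetical 2d image
boxes = np.array([[10, 20, 60, 80], [100, 120, 150, 170]])  # N x 4, [MIN_X, MIN_Y, MAX_X, MAX_Y]

predictor = get_sam_model(model_type="vit_b")
segmentation = batched_inference(
    predictor, image, batch_size=16, boxes=boxes,
    embedding_path="embeddings.zarr",  # optional on-disk cache for the embeddings
)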
7 changes: 5 additions & 2 deletions micro_sam/instance_segmentation.py
@@ -74,8 +74,11 @@ def mask_data_to_segmentation(
     for mask in masks:
         if mask["area"] < min_object_size:
             continue
-        segmentation[mask["segmentation"]] = seg_id
-        seg_id += 1
+
+        this_seg_id = mask.get("seg_id", seg_id)
+        segmentation[mask["segmentation"]] = this_seg_id
+
+        seg_id = this_seg_id + 1

     if with_background:
         seg_ids, sizes = np.unique(segmentation, return_counts=True)
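The effect of the new seg_id handling, as a small hedged sketch; it assumes only the "segmentation" and "area" fields are required by mask_data_to_segmentation:

import numpy as np
from micro_sam.instance_segmentation import mask_data_to_segmentation

shape = (4, 4)
m1, m2 = np.zeros(shape, dtype=bool), np.zeros(shape, dtype=bool)
m1[:2, :2], m2[2:, 2:] = True, True

masks = [
    {"segmentation": m1, "area": m1.sum(), "seg_id": 5},  # fixed id: written as 5
    {"segmentation": m2, "area": m2.sum()},  # no "seg_id": falls back to the running counter
]
segmentation = mask_data_to_segmentation(masks, shape, with_background=False, min_object_size=0)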
