Extend iterative prompt generators to return prompts for 3d (#692)
Extend support for iterative prompt generators to 3d
anwai98 authored Sep 29, 2024
1 parent 20e89ab commit 1a555a4
Showing 1 changed file with 73 additions and 46 deletions.
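For orientation, a minimal sketch of how the extended generator might be exercised after this change. The class and module are real (micro_sam.prompt_generators.IterativePromptGenerator); the no-argument constructor and the two-tuple return are read off the docstring and code in the diff below, and the random masks are placeholders chosen only to match the documented 2d and 3d shapes:

import torch
from micro_sam.prompt_generators import IterativePromptGenerator

prompt_generator = IterativePromptGenerator()

# 2d masks: float tensors of shape NUM_OBJECTS x 1 x H x W
segmentation_2d = (torch.rand(4, 1, 256, 256) > 0.5).to(torch.float32)
prediction_2d = (torch.rand(4, 1, 256, 256) > 0.5).to(torch.float32)
coords_2d, labels_2d = prompt_generator(segmentation_2d, prediction_2d)

# 3d masks (new with this commit): float tensors of shape NUM_OBJECTS x 1 x Z x H x W
segmentation_3d = (torch.rand(4, 1, 16, 256, 256) > 0.5).to(torch.float32)
prediction_3d = (torch.rand(4, 1, 16, 256, 256) > 0.5).to(torch.float32)
coords_3d, labels_3d = prompt_generator(segmentation_3d, prediction_3d)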
119 changes: 73 additions & 46 deletions micro_sam/prompt_generators.py
@@ -252,64 +252,82 @@ def __call__(
 class IterativePromptGenerator(PromptGeneratorBase):
     """Generate point prompts from an instance segmentation iteratively.
     """
-    def _get_positive_points(self, pos_region, overlap_region):
+    def _get_positive_points(self, pos_region, overlap_region, is_3d):
         positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
         # we may have objects without a positive region (= missing true foreground)
-        # in this case we just sample a point where the model was already correct
+        # in this case we just sample a positive point where the model was already correct
         positive_locations = [
             torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
             for pos_loc, ovlp_reg in zip(positive_locations, overlap_region)
         ]
-        # we sample one location for each object in the batch
+        # we sample one positive location for each object in the batch
         sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations]
-        # get the corresponding coordinates (Note that we flip the axis order here due to the expected order of SAM)
-        pos_coordinates = [
-            [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices)
-        ]
+        # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM)
+        if is_3d:
+            pos_coordinates = [
+                [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]]
+                for pos_loc, idx in zip(positive_locations, sampled_indices)
+            ]
+        else:
+            pos_coordinates = [
+                [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices)
+            ]

         # make sure that we still have the correct batch size
         assert len(pos_coordinates) == pos_region.shape[0]
         pos_labels = [1] * len(pos_coordinates)

         return pos_coordinates, pos_labels

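An aside on the flipped indexing above: torch.where returns indices in array order, (y, x) in 2d and (z, y, x) in 3d, while the point prompts are assembled in the reversed (x, y) / (x, y, z) order that SAM expects. A standalone illustration (not part of the diff):

import torch

mask_2d = torch.zeros(4, 5)
mask_2d[2, 3] = 1                            # foreground pixel at row (y) 2, column (x) 3
loc = torch.where(mask_2d)                   # index order (y, x): (tensor([2]), tensor([3]))
print([loc[-1][0], loc[-2][0]])              # flipped to (x, y): [tensor(3), tensor(2)]

mask_3d = torch.zeros(2, 4, 5)
mask_3d[1, 2, 3] = 1                         # voxel at plane (z) 1, row (y) 2, column (x) 3
loc = torch.where(mask_3d)                   # index order (z, y, x)
print([loc[-1][0], loc[-2][0], loc[-3][0]])  # flipped to (x, y, z): [tensor(3), tensor(2), tensor(1)]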
-    # TODO get rid of this looped implementation and use proper batched computation instead
-    def _get_negative_points(self, negative_region_batched, true_object_batched):
-        device = negative_region_batched.device
-
-        negative_coordinates, negative_labels = [], []
-        for neg_region, true_object in zip(negative_region_batched, true_object_batched):
-
-            tmp_neg_loc = torch.where(neg_region)
-            if torch.stack(tmp_neg_loc).shape[-1] == 0:
-                tmp_true_loc = torch.where(true_object)
-                x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
-                bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
-                                    torch.max(x_coords) + 1, torch.max(y_coords) + 1])
-                bbox_mask = torch.zeros_like(true_object).squeeze(0)
-
-                custom_df = 3  # custom dilation factor used to expand the bounding box by a few pixels
-                bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]),
-                          max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1])] = 1
-                bbox_mask = bbox_mask[None].to(device)
-
-                background_mask = torch.abs(bbox_mask - true_object)
-                tmp_neg_loc = torch.where(background_mask)
-
-                # there is a chance that the object is too small to return a decent-sized bounding box
-                # hence we might not find points there either, so we sample points from the true background
-                if torch.stack(tmp_neg_loc).shape[-1] == 0:
-                    tmp_neg_loc = torch.where(true_object == 0)
+    def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3):
+        true_loc = torch.where(true_object)
+        bbox = torch.stack(
+            [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1]
+        )

-            neg_index = np.random.choice(len(tmp_neg_loc[1]))
-            neg_coordinates = [tmp_neg_loc[1][neg_index], tmp_neg_loc[2][neg_index]]
-            neg_coordinates = neg_coordinates[::-1]
-            neg_labels = 0
+        # custom dilation factor (custom_df) used to expand the bounding box by a few pixels
+        bbox_mask = torch.zeros_like(true_object).squeeze(0)
+        bbox_mask[
+            max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]),
+            max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1])
+        ] = 1
+        bbox_mask = bbox_mask[None].to(true_object.device)
+        background_mask = torch.abs(bbox_mask - true_object)
+        return torch.where(background_mask)

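In essence, the helper above takes the object's bounding box, expands it by a few pixels (custom_df), and treats everything inside the expanded box but outside the object as candidate negative locations. A standalone 2d sketch of the same idea (names are illustrative, not from the module):

import torch

def negative_candidates_in_bbox(obj_mask: torch.Tensor, dilation: int = 3):
    """Indices of background pixels inside the object's dilated bounding box (obj_mask: binary H x W)."""
    ys, xs = torch.where(obj_mask)
    y0, x0 = int(ys.min()), int(xs.min())
    y1, x1 = int(ys.max()) + 1, int(xs.max()) + 1

    # dilate the bounding box by a few pixels, clipped to the image borders
    box = torch.zeros_like(obj_mask)
    box[max(y0 - dilation, 0): min(y1 + dilation, obj_mask.shape[0]),
        max(x0 - dilation, 0): min(x1 + dilation, obj_mask.shape[1])] = 1

    # keep pixels that are inside the dilated box but not on the object itself
    return torch.where(box - obj_mask == 1)

If the object fills the whole clipped box, the returned index tuple can still be empty, which is why the method below falls back to sampling from the full true background.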
+    def _get_negative_points(self, neg_region, true_object, is_3d):
+        # we have a valid negative region (i.e. a region where the model predicted foreground but the ground-truth is background)
+        negative_locations = [torch.where(neg_reg) for neg_reg in neg_region]
+        # we may have objects without a negative region (= no rectifications required)
+        # in this case we sample a negative point in the outer periphery of the object, inside its dilated bounding box
+        negative_locations = [
+            self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc
+            for neg_loc, true_obj in zip(negative_locations, true_object)
+        ]
+        # there is a chance that the object is too small to return a decent-sized bounding box
+        # hence we might not find points there either; therefore, we sample points from the true background
+        negative_locations = [
+            torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc
+            for neg_loc, true_obj in zip(negative_locations, true_object)
+        ]
+        # we sample one negative location for each object in the batch
+        sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations]
+        # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM)
+        if is_3d:
+            neg_coordinates = [
+                [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]]
+                for neg_loc, idx in zip(negative_locations, sampled_indices)
+            ]
+        else:
+            neg_coordinates = [
+                [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices)
+            ]

-            negative_coordinates.append(neg_coordinates)
-            negative_labels.append(neg_labels)
+        # make sure that we still have the correct batch size
+        assert len(neg_coordinates) == neg_region.shape[0]
+        neg_labels = [0] * len(neg_coordinates)

-        return negative_coordinates, negative_labels
+        return neg_coordinates, neg_labels

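Taken together, the negative sampling above degrades gracefully in three stages per object; schematically (a sketch of the control flow, not the actual batched implementation):

import torch

def sample_negative_location(neg_region, true_object, bbox_fallback):
    # 1) prefer the model's false-positive region for this object
    loc = torch.where(neg_region)
    # 2) otherwise sample around the object, inside its dilated bounding box
    if len(loc[0]) == 0:
        loc = bbox_fallback(true_object)
    # 3) otherwise fall back to any true background pixel
    if len(loc[0]) == 0:
        loc = torch.where(true_object == 0)
    return loc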
     def __call__(
         self,
@@ -320,24 +338,33 @@ def __call__(
         """Generate the prompts for each object iteratively in the segmentation.

         Args:
-            The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
-            The predicted objects. Expects a float tensor of the same shape as the segmentation.
+            segmentation: The groundtruth segmentation.
+                Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W).
+            prediction: The predicted objects. Expects a float tensor of the same shape as the segmentation.

         Returns:
             The updated point prompt coordinates.
             The updated point prompt labels.
         """
-        assert segmentation.shape == prediction.shape
         device = prediction.device
+        assert segmentation.shape == prediction.shape, \
+            "The segmentation and prediction tensors should have the same shape."

+        if segmentation.ndim == 5:  # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W
+            is_3d = True
+        elif segmentation.ndim == 4:  # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W
+            is_3d = False
+        else:
+            raise ValueError("The segmentation and prediction tensors should have either 4 or 5 dimensions.")
+
         true_object = segmentation.to(device)
         expected_diff = (prediction - true_object)
         neg_region = (expected_diff == 1).to(torch.float32)
         pos_region = (expected_diff == -1)
         overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32)

-        pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region)
-        neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object)
+        pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d)
+        neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d)
         assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)

         pos_coordinates = torch.tensor(pos_coordinates)[:, None]
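For intuition on the region computation in __call__: prediction - segmentation is +1 where the model hallucinated foreground (candidates for negative clicks), -1 where it missed foreground (candidates for positive clicks), and the overlap region is the fallback source for positive clicks when nothing was missed. A toy 1 x 4 example (standalone, not from the module):

import torch

gt = torch.tensor([[0., 1., 1., 0.]])
pred = torch.tensor([[1., 1., 0., 0.]])

diff = pred - gt
neg_region = (diff == 1).to(torch.float32)    # false positives
pos_region = (diff == -1)                     # false negatives
overlap_region = torch.logical_and(pred == 1, gt == 1).to(torch.float32)

print(neg_region)      # tensor([[1., 0., 0., 0.]])
print(pos_region)      # tensor([[False, False,  True, False]])
print(overlap_region)  # tensor([[0., 1., 0., 0.]])

Because these operations are elementwise, the same three lines carry over unchanged to the 3d case with an extra leading Z axis.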