Skip to content

Commit

Permalink
Adapting Iterative Prompting to updated prompt generator
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 30, 2023
1 parent 411d48f commit d22a187
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions micro_sam/evaluation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def run_inference_with_prompts(
prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
batch_size: int = 512,
) -> None:
"""Run segment anything inference for multiple images using prompts derived form groundtruth.
"""Run segment anything inference for multiple images using prompts derived from groundtruth.
Args:
predictor: The SegmentAnything predictor.
Expand Down Expand Up @@ -395,22 +395,14 @@ def _save_segmentation(masks, prediction_path):
imageio.imwrite(prediction_path, segmentation)


@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
model,
image,
gt,
n_iterations,
device,
use_boxes,
prediction_paths,
batch_size,
model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size
):
assert len(prediction_paths) == n_iterations, f"{len(prediction_paths)}, {n_iterations}"
to_sam_inputs = ConvertToSamInputs()

image = torch.from_numpy(
image[None, None] if image.ndim == 2 else image[None]
)
image = torch.from_numpy(image[None, None] if image.ndim == 2 else image[None])
gt = torch.from_numpy(gt[None].astype("int32"))

n_pos = 0 if use_boxes else 1
Expand All @@ -419,7 +411,7 @@ def _run_inference_with_iterative_prompting_for_image(
input_images = torch.stack([model.preprocess(x=x["image"].to(device)) for x in batched_inputs], dim=0)
image_embeddings = model.image_embeddings_oft(input_images)

multimasking = n_pos == 1
multimasking = (n_pos == 1)
prompt_generator = IterativePromptGenerator()

n_samples = len(sampled_ids[0])
Expand Down Expand Up @@ -462,44 +454,55 @@ def _run_inference_with_iterative_prompting_for_image(
masks = (masks > 0.5).to(torch.float32)
final_masks.append(masks)

for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, this_batched_inputs, logits_masks):
next_coords, next_labels, _, _ = prompt_generator(_gt, _pred)
updated_point_coords = torch.cat([_inp["point_coords"], next_coords], dim=1) \
if "point_coords" in _inp.keys() else next_coords
updated_point_labels = torch.cat([_inp["point_labels"], next_labels], dim=1) \
if "point_labels" in _inp.keys() else next_labels

_inp["point_coords"] = updated_point_coords
_inp["point_labels"] = updated_point_labels
_inp["mask_inputs"] = logits
all_next_point_coords, all_next_point_labels = prompt_generator(sampled_binary_y, masks)
for next_coords, next_labels, _inputs, _logits in zip(all_next_point_coords, all_next_point_labels,
this_batched_inputs, logits_masks):
_inputs["point_coords"] = torch.cat([_inputs["point_coords"], next_coords], dim=1) \
if "point_coords" in _inputs.keys() else next_coords
_inputs["point_labels"] = torch.cat([_inputs["point_labels"], next_labels], dim=1) \
if "point_labels" in _inputs.keys() else next_labels
_inputs["mask_inputs"] = _logits

final_masks = torch.cat(final_masks, dim=1)
_save_segmentation(final_masks, prediction_paths[iteration])


def run_inference_with_iterative_prompting(
checkpoint_path: Union[str, os.PathLike],
model_type: str,
image_paths: List[Union[str, os.PathLike]],
gt_paths: List[Union[str, os.PathLike]],
prediction_root: Union[str, os.PathLike],
use_boxes: bool,
model_type: str = "vit_b",
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
device: Optional[str] = None,
n_iterations: int = 8,
batch_size: int = 32,
) -> None:
"""@private"""
"""Run segment anything inference for multiple images using prompts iteratively
derived from model outputs and groundtruth
Args:
image_paths: The image file paths
gt_paths: The ground-truth segmentation file paths
prediction_root: TODO
use_box: Whether to use box prompts
model_type: Name of the vision transformer to be used for sam inference
checkpoint_path: The path to SAM model checkpoints
device: The device specification to enable GPU usage
n_iterations: The number of iterations to perform for the iterative prompting strategy
batch_size: The batch size used for batched predictions
"""
warnings.warn("The iterative prompting functionality is not working correctly yet.")

device = util._get_device(device)
model = get_trainable_sam_model(model_type, checkpoint_path)

# create all prediction folders
# create all prediction folders for all intermediate iterations
for i in range(n_iterations):
os.makedirs(os.path.join(prediction_root, f"iteration{i:02}"), exist_ok=True)

for image_path, gt_path in tqdm(
zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with iterative prompting for all images"
):
image_name = os.path.basename(image_path)

Expand All @@ -511,10 +514,9 @@ def run_inference_with_iterative_prompting(
assert os.path.exists(gt_path), gt_path

image = imageio.imread(image_path)
gt = imageio.imread(gt_path).astype("uint32")
gt = imageio.imread(gt_path)
gt = relabel_sequential(gt)[0]

with torch.no_grad():
_run_inference_with_iterative_prompting_for_image(
model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
)
_run_inference_with_iterative_prompting_for_image(
model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
)

0 comments on commit d22a187

Please sign in to comment.