diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py
index 93ae61ad..cdabebb7 100644
--- a/micro_sam/evaluation/inference.py
+++ b/micro_sam/evaluation/inference.py
@@ -308,7 +308,7 @@ def run_inference_with_prompts(
     prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
     batch_size: int = 512,
 ) -> None:
-    """Run segment anything inference for multiple images using prompts derived form groundtruth.
+    """Run segment anything inference for multiple images using prompts derived from groundtruth.

     Args:
         predictor: The SegmentAnything predictor.
@@ -395,22 +395,14 @@ def _save_segmentation(masks, prediction_path):
     imageio.imwrite(prediction_path, segmentation)


+@torch.no_grad()
 def _run_inference_with_iterative_prompting_for_image(
-    model,
-    image,
-    gt,
-    n_iterations,
-    device,
-    use_boxes,
-    prediction_paths,
-    batch_size,
+    model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size
 ):
     assert len(prediction_paths) == n_iterations, f"{len(prediction_paths)}, {n_iterations}"
     to_sam_inputs = ConvertToSamInputs()

-    image = torch.from_numpy(
-        image[None, None] if image.ndim == 2 else image[None]
-    )
+    image = torch.from_numpy(image[None, None] if image.ndim == 2 else image[None])
     gt = torch.from_numpy(gt[None].astype("int32"))

     n_pos = 0 if use_boxes else 1
@@ -419,7 +411,7 @@ def _run_inference_with_iterative_prompting_for_image(
     input_images = torch.stack([model.preprocess(x=x["image"].to(device)) for x in batched_inputs], dim=0)
     image_embeddings = model.image_embeddings_oft(input_images)

-    multimasking = n_pos == 1
+    multimasking = (n_pos == 1)
     prompt_generator = IterativePromptGenerator()

     n_samples = len(sampled_ids[0])
@@ -462,44 +454,55 @@ def _run_inference_with_iterative_prompting_for_image(
             masks = (masks > 0.5).to(torch.float32)
             final_masks.append(masks)

-            for _pred, _gt, _inp, logits in zip(masks, sampled_binary_y, this_batched_inputs, logits_masks):
-                next_coords, next_labels, _, _ = prompt_generator(_gt, _pred)
-                updated_point_coords = torch.cat([_inp["point_coords"], next_coords], dim=1) \
-                    if "point_coords" in _inp.keys() else next_coords
-                updated_point_labels = torch.cat([_inp["point_labels"], next_labels], dim=1) \
-                    if "point_labels" in _inp.keys() else next_labels
-
-                _inp["point_coords"] = updated_point_coords
-                _inp["point_labels"] = updated_point_labels
-                _inp["mask_inputs"] = logits
+            all_next_point_coords, all_next_point_labels = prompt_generator(sampled_binary_y, masks)
+            for next_coords, next_labels, _inputs, _logits in zip(all_next_point_coords, all_next_point_labels,
+                                                                  this_batched_inputs, logits_masks):
+                _inputs["point_coords"] = torch.cat([_inputs["point_coords"], next_coords], dim=1) \
+                    if "point_coords" in _inputs.keys() else next_coords
+                _inputs["point_labels"] = torch.cat([_inputs["point_labels"], next_labels], dim=1) \
+                    if "point_labels" in _inputs.keys() else next_labels
+                _inputs["mask_inputs"] = _logits

         final_masks = torch.cat(final_masks, dim=1)
         _save_segmentation(final_masks, prediction_paths[iteration])


 def run_inference_with_iterative_prompting(
-    checkpoint_path: Union[str, os.PathLike],
-    model_type: str,
     image_paths: List[Union[str, os.PathLike]],
     gt_paths: List[Union[str, os.PathLike]],
     prediction_root: Union[str, os.PathLike],
     use_boxes: bool,
+    model_type: str = "vit_b",
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     device: Optional[str] = None,
     n_iterations: int = 8,
     batch_size: int = 32,
 ) -> None:
-    """@private"""
+    """Run segment anything inference for multiple images using prompts iteratively
+        derived from model outputs and groundtruth.
+
+    Args:
+        image_paths: The image file paths.
+        gt_paths: The ground-truth segmentation file paths.
+        prediction_root: The root folder where the predictions for each iteration are saved.
+        use_boxes: Whether to use box prompts.
+        model_type: The name of the vision transformer to be used for SAM inference.
+        checkpoint_path: The path to the SAM model checkpoint.
+        device: The device specification to enable GPU usage.
+        n_iterations: The number of iterations to perform for the iterative prompting strategy.
+        batch_size: The batch size used for batched predictions.
+    """
     warnings.warn("The iterative prompting functionality is not working correctly yet.")

     device = util._get_device(device)
     model = get_trainable_sam_model(model_type, checkpoint_path)

-    # create all prediction folders
+    # create the prediction folders for all intermediate iterations
     for i in range(n_iterations):
         os.makedirs(os.path.join(prediction_root, f"iteration{i:02}"), exist_ok=True)

     for image_path, gt_path in tqdm(
-        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
+        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with iterative prompting for all images"
     ):
         image_name = os.path.basename(image_path)

@@ -511,10 +514,9 @@ def run_inference_with_iterative_prompting(
         assert os.path.exists(gt_path), gt_path

         image = imageio.imread(image_path)
-        gt = imageio.imread(gt_path).astype("uint32")
+        gt = imageio.imread(gt_path)
         gt = relabel_sequential(gt)[0]

-        with torch.no_grad():
-            _run_inference_with_iterative_prompting_for_image(
-                model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
-            )
+        _run_inference_with_iterative_prompting_for_image(
+            model, image, gt, n_iterations, device, use_boxes, prediction_paths, batch_size,
+        )
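
For reference, a minimal usage sketch of the updated run_inference_with_iterative_prompting signature; the data paths and the chosen model type below are placeholders, not part of this patch:

    from glob import glob

    from micro_sam.evaluation.inference import run_inference_with_iterative_prompting

    # placeholder inputs; any matching lists of image and label files work here
    image_paths = sorted(glob("data/images/*.tif"))
    gt_paths = sorted(glob("data/labels/*.tif"))

    run_inference_with_iterative_prompting(
        image_paths, gt_paths,
        prediction_root="predictions/iterative_prompting",  # one sub-folder per iteration is created here
        use_boxes=False,
        model_type="vit_b",    # now an optional keyword argument, defaults to "vit_b"
        checkpoint_path=None,  # now optional; None should fall back to the default weights for the model type
        n_iterations=8,
    )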