Commit b8861d5

Update evaluation functions for optional verbose (#585)
anwai98 authored May 5, 2024
1 parent 7e460ab commit b8861d5
Showing 4 changed files with 36 additions and 15 deletions.
5 changes: 1 addition & 4 deletions micro_sam/evaluation/evaluation.py
@@ -115,14 +115,11 @@ def run_evaluation_for_iterative_prompting(
list_of_results = []
prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*")))
for pred_folder in prediction_folders:
print("Evaluating", pred_folder)
print("Evaluating", os.path.split(pred_folder)[-1])
pred_paths = sorted(glob(os.path.join(pred_folder, "*")))
result = run_evaluation(gt_paths=gt_paths, prediction_paths=pred_paths, save_path=None)
list_of_results.append(result)
print(result)

res_df = pd.concat(list_of_results, ignore_index=True)
res_df.to_csv(csv_path)


# TODO function to evaluate full experiment and resave in one table
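
For orientation, here is a minimal, hypothetical sketch of how the evaluation entry point above might be driven. Only gt_paths, prediction_root and csv_path are visible in the hunk, so the sketch assumes these make up the call; the real signature may differ:

import os
from glob import glob

from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting

# Hypothetical locations: ground-truth label images and the folder that
# contains one "iteration*" sub-folder of predictions per prompting iteration.
gt_paths = sorted(glob(os.path.join("data", "labels", "*.tif")))
prediction_root = os.path.join("results", "iterative_prompting")
csv_path = os.path.join("results", "iterative_prompting_results.csv")

run_evaluation_for_iterative_prompting(
    gt_paths=gt_paths, prediction_root=prediction_root, csv_path=csv_path
)
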
23 changes: 17 additions & 6 deletions micro_sam/evaluation/inference.py
@@ -393,6 +393,8 @@ def _run_inference_with_iterative_prompting_for_image(
prediction_paths,
use_masks=False
) -> None:
verbose_embeddings = False

prompt_generator = IterativePromptGenerator()

gt_ids = np.unique(gt)[1:]
@@ -426,10 +428,17 @@ def _run_inference_with_iterative_prompting_for_image(
logits_masks = None

batched_outputs = batched_inference(
predictor, image, batch_size,
boxes=boxes, points=points, point_labels=point_labels,
multimasking=multimasking, embedding_path=embedding_path,
return_instance_segmentation=False, logits_masks=logits_masks
predictor=predictor,
image=image,
batch_size=batch_size,
boxes=boxes,
points=points,
point_labels=point_labels,
multimasking=multimasking,
embedding_path=embedding_path,
return_instance_segmentation=False,
logits_masks=logits_masks,
verbose_embeddings=verbose_embeddings,
)

# switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
@@ -484,7 +493,7 @@ def run_inference_with_iterative_prompting(
around which points will not be sampled.
batch_size: The batch size used for batched predictions.
n_iterations: The number of iterations for iterative prompting.
use_masks: Whether to make use of logits from previous prompt-based segmentation
use_masks: Whether to make use of logits from previous prompt-based segmentation.
"""
if len(image_paths) != len(gt_paths):
raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")
@@ -497,7 +506,9 @@
print("The iterative prompting will make use of logits masks from previous iterations.")

for image_path, gt_path in tqdm(
zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with iterative prompting for all images"
zip(image_paths, gt_paths),
total=len(image_paths),
desc="Run inference with iterative prompting for all images",
):
image_name = os.path.basename(image_path)

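
A rough usage sketch for the function above, not taken from this commit: image_paths, gt_paths and use_masks appear in the hunks, while the predictor and the directory arguments (embedding_dir, prediction_dir) are assumed names added for illustration only:

import os
from glob import glob

from micro_sam.evaluation.inference import run_inference_with_iterative_prompting
from micro_sam.util import get_sam_model  # assumed helper for loading a SAM predictor

image_paths = sorted(glob(os.path.join("data", "images", "*.tif")))  # hypothetical paths
gt_paths = sorted(glob(os.path.join("data", "labels", "*.tif")))     # hypothetical paths

predictor = get_sam_model(model_type="vit_b")

run_inference_with_iterative_prompting(
    predictor=predictor,            # assumed parameter name
    image_paths=image_paths,
    gt_paths=gt_paths,
    embedding_dir="embeddings",     # assumed parameter name
    prediction_dir="predictions",   # assumed parameter name
    use_masks=True,                 # reuse logits from the previous iteration as mask prompts
)
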
15 changes: 12 additions & 3 deletions micro_sam/evaluation/instance_segmentation.py
@@ -181,13 +181,15 @@ def run_instance_segmentation_grid_search(
result_dir: Folder to cache the evaluation results per image.
embedding_dir: Folder to cache the image embeddings.
fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
verbose_gs: Whether to run the grid-search for individual images in a verbose mode.
image_key: Key for loading the image data from a more complex file format like HDF5.
If not given a simple image format like tif is assumed.
gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
If not given a simple image format like tif is assumed.
rois: Regions of interest to restrict the evaluation to.
"""
verbose_embeddings = False

assert len(image_paths) == len(gt_paths)
fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs

@@ -229,7 +231,9 @@
else:
assert predictor is not None
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
)
segmenter.initialize(image, image_embeddings)

_grid_search_iteration(
@@ -255,6 +259,8 @@ def run_instance_segmentation_inference(
generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
"""

verbose_embeddings = False

generate_kwargs = {} if generate_kwargs is None else generate_kwargs
predictor = segmenter._predictor
min_object_size = generate_kwargs.get("min_mask_region_area", 0)
@@ -271,7 +277,9 @@
image = imageio.imread(image_path)

embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
)

segmenter.initialize(image, image_embeddings)
masks = segmenter.generate(**generate_kwargs)
@@ -384,6 +392,7 @@ def run_instance_segmentation_grid_search_and_inference(
best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
best_param_str = ", ".join(f"{k} = {v}" for k, v in best_kwargs.items())
print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
print()

save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)

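
The changes in this file follow one pattern: a hard-coded flag, verbose_embeddings = False, is set at the top of each function and threaded into util.precompute_image_embeddings via its verbose keyword, so the embedding progress bar stays quiet during grid search and inference. A condensed sketch of that pattern; the helper wrapper is illustrative, only the precompute call mirrors the diff:

import os
import imageio

from micro_sam import util

verbose_embeddings = False  # silence the progress output of embedding precomputation

def precompute_embeddings_for_image(predictor, image_path, embedding_dir):
    """Illustrative helper: compute (or load cached) embeddings for a single image."""
    image = imageio.imread(image_path)
    image_name = os.path.basename(image_path)
    embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
    # This call mirrors the updated lines in the diff above.
    return util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
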
8 changes: 6 additions & 2 deletions micro_sam/inference.py
Expand Up @@ -26,7 +26,8 @@ def batched_inference(
return_instance_segmentation: bool = True,
segmentation_ids: Optional[list] = None,
reduce_multimasking: bool = True,
logits_masks: Optional[torch.Tensor] = None
logits_masks: Optional[torch.Tensor] = None,
verbose_embeddings: bool = True,
):
"""Run batched inference for input prompts.
@@ -51,6 +52,7 @@
highest ious from multimasking
logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
Whether to use the logits masks from previous segmentation.
verbose_embeddings: Whether to show progress outputs of computing image embeddings.
Returns:
The predicted segmentation masks.
@@ -103,7 +105,9 @@
)

# Compute the image embeddings.
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
)
util.set_precomputed(predictor, image_embeddings)

# Determine the number of batches.
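
With the new keyword exposed, callers of batched_inference can now silence the embedding progress output explicitly. A minimal sketch, assuming a SAM predictor from micro_sam.util.get_sam_model and a single hypothetical box prompt; check the function docstring for the expected prompt format:

import numpy as np

from micro_sam.inference import batched_inference
from micro_sam.util import get_sam_model  # assumed helper for loading a SAM predictor

predictor = get_sam_model(model_type="vit_b")

# Placeholder inputs: a random 2d image and one hypothetical box prompt.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
boxes = np.array([[100, 100, 200, 200]])

segmentation = batched_inference(
    predictor=predictor,
    image=image,
    batch_size=16,
    boxes=boxes,
    verbose_embeddings=False,  # new in this commit; defaults to True to keep the old behavior
)
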
