Skip to content

Commit

Permalink
Merge pull request #278 from computational-cell-analytics/dev
Browse files Browse the repository at this point in the history
Merge dev into master
  • Loading branch information
constantinpape authored Nov 18, 2023
2 parents 146fb0f + c236df3 commit 2a55ec5
Show file tree
Hide file tree
Showing 26 changed files with 1,019 additions and 376 deletions.
104 changes: 58 additions & 46 deletions examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import napari

from micro_sam import instance_segmentation, util
from micro_sam.multi_dimensional_segmentation import segment_3d_from_slice


def cell_segmentation():
Expand Down Expand Up @@ -32,36 +33,15 @@ def cell_segmentation():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.EmbeddingMaskGenerator(predictor, min_initial_size=10)

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation,
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
instances_mws = instance_segmentation.mask_data_to_segmentation(
instances_mws, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_labels(instances)
napari.run()


Expand Down Expand Up @@ -94,39 +74,71 @@ def cell_segmentation_with_tiling():

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
instances_amg = amg.generate(pred_iou_thresh=0.88)
instances_amg = instance_segmentation.mask_data_to_segmentation(
instances_amg, shape=image.shape, with_background=True
instances = amg.generate(pred_iou_thresh=0.88)
instances = instance_segmentation.mask_data_to_segmentation(
instances, shape=image.shape, with_background=True
)

# Use the mutex waterhsed based instance segmentation logic.
# Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
# These initial masks are used as prompts for the actual instance segmentation.
# This class uses the same overall design as 'AutomaticMaskGenerator'.

# Create the automatic mask generator class.
amg_mws = instance_segmentation.TiledEmbeddingMaskGenerator(predictor, min_initial_size=10)
# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances)
v.add_labels(instances)
napari.run()

# Initialize the mask generator with the image and the pre-computed embeddings.
amg_mws.initialize(image, embeddings, verbose=True)

# Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
# without having to call initialize again.
# NOTE: the main advantage of this method is that it's faster than the original implementation.
# however the quality is not as high as the original instance segmentation quality yet.
instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
def segmentation_in_3d():
"""Run instance segmentation in 3d, for segmenting all objects that intersect
with a given slice. If you use a fine-tuned model for this then you should
first find good parameters for 2d segmentation.
"""
import imageio.v3 as imageio
from micro_sam.sample_data import fetch_nucleus_3d_example_data

# Load the example image data: 3d nucleus segmentation.
path = fetch_nucleus_3d_example_data("./data")
data = imageio.imread(path)

# Load the SAM model for prediction.
model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc.
checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here.
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)

# Run 3d segmentation for a given slice. Will segment all objects found in that slice
# throughout the volume.

# The slice that is used for segmentation in 2d. If you don't specify a slice
# then the middle slice is used.
z_slice = data.shape[0] // 2

# The threshold for filtering objects in the 2d segmentation based on the model's
# predicted iou score. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
pred_iou_thresh = 0.88

# The threshold for filtering objects in the 2d segmentation based on the model's
# stability score for a given object. If you use a custom model you should first find a good setting
# for this value, e.g. with the 2d annotation tool.
stability_score_thresh = 0.95

instances = segment_3d_from_slice(
predictor, data, z=z_slice,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
verbose=True
)

# Show the results.
v = napari.Viewer()
v.add_image(image)
v.add_labels(instances_amg)
v.add_labels(instances_mws)
v.add_image(data)
v.add_labels(instances)
napari.run()


def main():
cell_segmentation()
# cell_segmentation()
# cell_segmentation_with_tiling()
segmentation_in_3d()


if __name__ == "__main__":
Expand Down
11 changes: 11 additions & 0 deletions finetuning/livecell/evaluation/iterative.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#! /bin/bash
#SBATCH -c 16
#SBATCH --mem 48G
#SBATCH -t 6:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --job-name=sam-iterative-prompting

source activate sam
python iterative_prompting.py $@
85 changes: 68 additions & 17 deletions finetuning/livecell/evaluation/iterative_prompting.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
import os
import pandas as pd
from glob import glob

from micro_sam.evaluation.inference import run_inference_with_iterative_prompting
from micro_sam.evaluation import inference
from micro_sam.evaluation.evaluation import run_evaluation

from util import get_checkpoint, get_paths
from util import get_paths, get_checkpoint, MODELS

LIVECELL_GT_ROOT = "/scratch-grete/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
# TODO update to make fit other models
PREDICTION_ROOT = "./pred_interactive_prompting"
LIVECELL_GT_ROOT = "/scratch/projects/nim00007/data/LiveCELL/annotations_corrected/livecell_test_images"
PREDICTION_ROOT = "/scratch/projects/nim00007/sam/iterative_evaluation"


def run_interactive_prompting():
prediction_root = PREDICTION_ROOT
def get_prediction_root(start_with_box_prompt, model_description, root_dir=PREDICTION_ROOT):
if start_with_box_prompt:
prediction_root = os.path.join(root_dir, model_description, "start_with_box")
else:
prediction_root = os.path.join(root_dir, model_description, "start_with_point")

return prediction_root


def run_interactive_prompting(predictor, start_with_box_prompt, model_description, prediction_root):
# we organize all the folders with data from this experiment below
embedding_folder = os.path.join(PREDICTION_ROOT, model_description, "embeddings")
os.makedirs(embedding_folder, exist_ok=True)

checkpoint, model_type = get_checkpoint("vit_b")
image_paths, gt_paths = get_paths()

run_inference_with_iterative_prompting(
checkpoint, model_type, image_paths, gt_paths,
prediction_root, use_boxes=False, batch_size=16,
inference.run_inference_with_iterative_prompting(
predictor=predictor,
image_paths=image_paths,
gt_paths=gt_paths,
embedding_dir=embedding_folder,
prediction_dir=prediction_root,
start_with_box_prompt=start_with_box_prompt
)


Expand All @@ -33,20 +47,57 @@ def get_pg_paths(pred_folder):
return pred_paths, gt_paths


def evaluate_interactive_prompting():
prediction_root = PREDICTION_ROOT
def evaluate_interactive_prompting(prediction_root, start_with_box_prompt, model_description):
assert os.path.exists(prediction_root), prediction_root

csv_save_dir = f"./iterative_prompting_results/{model_description}"
os.makedirs(csv_save_dir, exist_ok=True)
csv_path = os.path.join(csv_save_dir, "start_with_box.csv" if start_with_box_prompt else "start_with_point.csv")
if os.path.exists(csv_path):
print("The evaluated results for the expected setting already exist here:", csv_path)
return

prediction_folders = sorted(glob(os.path.join(prediction_root, "iteration*")))
list_of_results = []
for pred_folder in prediction_folders:
print("Evaluating", pred_folder)
pred_paths, gt_paths = get_pg_paths(pred_folder)
res = run_evaluation(gt_paths, pred_paths, save_path=None)
list_of_results.append(res)
print(res)

df = pd.concat(list_of_results, ignore_index=True)
df.to_csv(csv_path)

def main():
# run_interactive_prompting()
evaluate_interactive_prompting()

def main(args):
start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point
model_description = args.model # overwrite to specify the choice of vanilla / finetuned models

# add the root prediction path where you would like to save the iterative prompting results
prediction_root = get_prediction_root(start_with_box_prompt, model_description)

# get the model checkpoints and desired model name to initialize the predictor
if args.checkpoint is None and model_description in MODELS.keys():
checkpoint, model_type = get_checkpoint(model_description)
else:
checkpoint = args.checkpoint
model_type = model_description[:5]
# get the predictor to perform inference
predictor = inference.get_predictor(checkpoint, model_type)

run_interactive_prompting(predictor, start_with_box_prompt, model_description, prediction_root)
evaluate_interactive_prompting(prediction_root, start_with_box_prompt, model_description)


if __name__ == "__main__":
main()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box")
parser.add_argument(
"-m", "--model", type=str, # options: "vit_h", "vit_h_generalist", "vit_h_specialist"
help="Provide the model type to initialize the predictor"
)
parser.add_argument("-c", "--checkpoint", type=str, default=None)
args = parser.parse_args()
main(args)
14 changes: 7 additions & 7 deletions finetuning/livecell/evaluation/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

DATA_ROOT = "/scratch/projects/nim00007/data/LiveCELL"
EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/livecell"
PROMPT_FOLDER = "/scratch-grete/projects/nim00007/sam/experiments/prompts/livecell"
PROMPT_FOLDER = "/scratch/projects/nim00007/sam/experiments/prompts/livecell"
MODELS = {
"vit_b": "/scratch-grete/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth",
"vit_h": "/scratch-grete/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth",
"vit_b_specialist": "/scratch-grete/projects/nim00007/sam/LM/LiveCELL/vit_b/best.pt",
"vit_h_specialist": "/scratch-grete/projects/nim00007/sam/LM/LiveCELL/vit_h/best.pt",
"vit_b_generalist": "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_b/best.pt",
"vit_h_generalist": "/scratch-grete/projects/nim00007/sam/LM/generalist/vit_h/best.pt",
"vit_b": "/scratch/projects/nim00007/sam/vanilla/sam_vit_b_01ec64.pth",
"vit_h": "/scratch/projects/nim00007/sam/vanilla/sam_vit_h_4b8939.pth",
"vit_b_specialist": "/scratch/projects/nim00007/sam/models/LM/LiveCELL/vit_b/best.pt",
"vit_h_specialist": "/scratch/projects/nim00007/sam/models/LM/LiveCELL/vit_h/best.pt",
"vit_b_generalist": "/scratch/projects/nim00007/sam/models/LM/generalist/v2/vit_b/best.pt",
"vit_h_generalist": "/scratch/projects/nim00007/sam/models/LM/generalist/v2/vit_h/best.pt",
}


Expand Down
7 changes: 4 additions & 3 deletions finetuning/livecell_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def finetune_livecell(args):
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
model = sam_training.get_trainable_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

# all the stuff we need for training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
Expand Down Expand Up @@ -72,7 +72,8 @@ def finetune_livecell(args):
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False
compile_model=False,
mask_prob=0.5 # (optional) overwrite to provide the probability of using mask inputs while training
)
trainer.fit(args.iterations)
if args.export_path is not None:
Expand All @@ -89,7 +90,7 @@ def finetune_livecell(args):
def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
parser.add_argument(
"--input_path", "-i", default="",
"--input_path", "-i", default="/scratch/projects/nim00007/data/LiveCELL/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
Expand Down
Loading

0 comments on commit 2a55ec5

Please sign in to comment.