Embedding widget for napari plugin #235

Merged
Changes from 22 of 30 commits

Commits
21829d5
Start on embedding widget
GenevieveBuckley Oct 17, 2023
ac0412e
Add missing typing import in visualization.py file
GenevieveBuckley Oct 17, 2023
d8eb08b
Fix npe2 napari manifest validation
GenevieveBuckley Oct 17, 2023
4a2ed76
Graceful error handling for torch device backend selection
GenevieveBuckley Oct 19, 2023
38734f0
Embedding widget for napari plugin
GenevieveBuckley Oct 19, 2023
9bbf661
Merge branch 'dev' into embedding-widget
GenevieveBuckley Oct 19, 2023
f0b3f49
Fix _available_devices utility function
GenevieveBuckley Oct 19, 2023
c18af45
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Oct 19, 2023
45c1a57
Clearer variable names
GenevieveBuckley Oct 19, 2023
b9c37f0
Add GUI test for embedding widget
GenevieveBuckley Oct 19, 2023
3405039
Thread worker function doesn't actually return object the same way si…
GenevieveBuckley Oct 19, 2023
0314ef2
Fix rgb ndim calculation
GenevieveBuckley Oct 19, 2023
522b27b
Move location where global IMAGE_EMBEDDINGS is defined
GenevieveBuckley Oct 19, 2023
b39d5c9
Order of output files does not matter for test
GenevieveBuckley Oct 23, 2023
174a9c6
Fix ndim for rgb images in embedding_widget
GenevieveBuckley Oct 23, 2023
147797c
Let's be careful since now the progress bar is an unnamed argument to …
GenevieveBuckley Oct 23, 2023
9b1730f
Match progress bar with thread worker example from napari/examples
GenevieveBuckley Oct 23, 2023
5c573b9
Sanitize user string input to _get_device()
GenevieveBuckley Oct 24, 2023
464d7c3
Merge branch 'dev' into embedding-widget-singleton
GenevieveBuckley Nov 3, 2023
ae9b611
Workaround for issue 246
GenevieveBuckley Nov 3, 2023
c1fe954
Image embedding widget now uses singleton AnnotatorState
GenevieveBuckley Nov 3, 2023
8e75871
Embedding widget, ensure save directory exists and is empty
GenevieveBuckley Nov 3, 2023
1de747e
Don't set a device for custom checkpoint export
constantinpape Nov 4, 2023
1c68c03
More concise code with os.makedirs exist_ok
GenevieveBuckley Nov 5, 2023
856da60
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Nov 5, 2023
bfa8c57
Upgrade invalid embeddings path from user warning to runtime error (s…
GenevieveBuckley Nov 7, 2023
51e96bf
Move all computation into thread worker, allow previously computed em…
GenevieveBuckley Nov 7, 2023
c0c83a4
Add reset_state method to clear all attributes held in state
GenevieveBuckley Nov 7, 2023
a4ad1dc
Remove data_signature attribute from AnnotatorState attributes
GenevieveBuckley Nov 7, 2023
ca3c77e
Embedding widget, image_shape in annotator state
GenevieveBuckley Nov 7, 2023
6 changes: 6 additions & 0 deletions micro_sam/napari.yaml
@@ -23,6 +23,9 @@ contributions:
     - id: micro-sam.sample_data_segmentation
       python_name: micro_sam.sample_data:sample_data_segmentation
       title: Load segmentation sample data from micro-sam plugin
+    - id: micro-sam.embedding_widget
+      python_name: micro_sam.sam_annotator._widgets:embedding_widget
+      title: Embedding widget
   sample_data:
     - command: micro-sam.sample_data_image_series
       display_name: Image series example data
@@ -45,3 +48,6 @@
     - command: micro-sam.sample_data_segmentation
       display_name: Segmentation sample dataset
       key: micro-sam-segmentation
+  widgets:
+    - command: micro-sam.embedding_widget
+      display_name: Embedding widget
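The widget is registered through the npe2 manifest entries above. As a rough illustration only (not npe2's actual code), the manifest's python_name string resolves to the widget factory roughly like this:

# Hypothetical sketch of npe2-style resolution of "module:attribute" strings.
from importlib import import_module

def resolve_python_name(python_name):
    module_name, attribute = python_name.split(":")
    return getattr(import_module(module_name), attribute)

widget_factory = resolve_python_name("micro_sam.sam_annotator._widgets:embedding_widget")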
76 changes: 76 additions & 0 deletions micro_sam/sam_annotator/_widgets.py
@@ -0,0 +1,76 @@
from enum import Enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from magicgui import magic_factory, widgets
from napari.qt.threading import thread_worker

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.util import (
    ImageEmbeddings,
    get_sam_model,
    precompute_image_embeddings,
    _MODEL_URLS,
    _DEFAULT_MODEL,
    _available_devices,
)

if TYPE_CHECKING:
    import napari

[Codecov warning: added line #L20 was not covered by tests]

Model = Enum("Model", _MODEL_URLS)
available_devices_list = ["auto"] + _available_devices()


@magic_factory(
    pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'},
    call_button="Compute image embeddings",
    device={"choices": available_devices_list},
    save_path={"mode": "d"},  # choose a directory
)
def embedding_widget(
    pbar: widgets.ProgressBar,
    image: "napari.layers.Image",
    model: Model = Model.__getitem__(_DEFAULT_MODEL),
    device="auto",
    save_path: Optional[Path] = None,  # where embeddings for this image are cached (optional)
    optional_custom_weights: Optional[Path] = None,  # a filepath or URL to custom model weights
) -> ImageEmbeddings:
    """Image embedding widget."""
    # Make sure the save directory exists and is an empty directory
    if save_path is not None:
        if not save_path.exists():
            os.makedirs(save_path)

[Codecov warning: added line #L44 was not covered by tests]

        if not save_path.is_dir():
            raise NotADirectoryError(
                f"The user selected 'save_path' is not a directory: {save_path}"
            )

[Codecov warning: added line #L46 was not covered by tests]

        if len(os.listdir(save_path)) > 0:
Contributor:

I think this defeats the purpose of caching the embeddings: the idea is that a user can close the annotation tool and open it again later while pointing to the same save_path, so that the embeddings don't have to be recomputed. This check prevents this and makes caching the embeddings not useful.

Contributor:

This also relates to your comment from above:

> I've put in a bit of code to make sure the save_path directory exists and is empty. It prints an error message to the terminal if the directory is not empty.

Printing the error here is not a good idea. Maybe you intended this because of the problem with the hash comparison, but I don't want to print an error for the correct behavior of using already cached embeddings. We could use a warning instead for now and remove this once we figure out the issue with the hash.

Collaborator Author @GenevieveBuckley, Nov 7, 2023:

That's reasonable - I've changed the code to fix this problem. Now it works in all situations:

  1. Creating new image embeddings,
  2. Returning/loading existing image embeddings if the data_signature of the save directory matches our input image, and
  3. Making an error popup in the napari viewer if something goes wrong (data signature does not match, save directory is not a zarr array or empty folder, anything else unexpected that raises an actual error).
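A minimal sketch of the load-or-compute flow described in the list above (hypothetical helper and argument names; in micro-sam this logic lives inside precompute_image_embeddings):

# Hypothetical sketch only -- not the actual micro_sam implementation.
import zarr

def load_or_compute_embeddings(compute_fn, signature_fn, image_data, save_path=None):
    if save_path is None:
        return compute_fn(image_data)  # case 1: compute fresh, keep in memory only
    f = zarr.open(save_path, mode="a")
    if "data_signature" in f.attrs:
        if f.attrs["data_signature"] != signature_fn(image_data):
            # case 3: raising lets the GUI show an error popup
            raise RuntimeError("Cached embeddings at save_path were computed for different image data.")
        return {"features": f["features"][:], **dict(f.attrs)}  # case 2: reuse the cache
    return compute_fn(image_data, save_path=save_path)  # case 1: compute and cache to disk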

            raise RuntimeError(
                f"The user selected 'save_path' is not empty: {save_path}"
            )

[Codecov warning: added line #L50 was not covered by tests]

    state = AnnotatorState()
    # Initialize the model
    state.predictor = get_sam_model(device=device, model_type=model.name,
                                    checkpoint_path=optional_custom_weights)
    # Get image dimensions
    ndim = image.data.ndim
    if image.rgb:
        ndim -= 1

[Codecov warning: added line #L61 was not covered by tests]

    # Compute the image embeddings
    @thread_worker(connect={'started': pbar.show, 'returned': pbar.hide})
    def _compute_image_embedding(state, image_data, save_path, ndim=None):
        if save_path is not None:
            save_path = str(save_path)
        state.image_embeddings = precompute_image_embeddings(
            predictor=state.predictor,
            input_=image_data,
            save_path=save_path,
            ndim=ndim,
        )
        return state.image_embeddings

[Codecov warning: added lines #L66-L68 and #L74 were not covered by tests]

    # Calling the thread_worker-decorated function returns a napari FunctionWorker
    return _compute_image_embedding(state, image.data, save_path, ndim=ndim)
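Taken together, a minimal sketch of driving the factory-produced widget programmatically (an assumption-level example mirroring the test further below; napari normally instantiates the widget via the manifest entry):

import napari
from micro_sam.sam_annotator._widgets import embedding_widget, Model

viewer = napari.Viewer()
layer = viewer.open_sample("napari", "camera")[0]
widget = embedding_widget()  # the magic_factory call builds the FunctionGui widget
viewer.window.add_dock_widget(widget)
worker = widget(image=layer, model=Model.vit_t, device="cpu")  # returns a napari FunctionWorker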
38 changes: 33 additions & 5 deletions micro_sam/util.py
@@ -127,10 +127,7 @@
     return checkpoint_path
 
 
-def _get_device(device):
-    if device is not None:
-        return device
-
+def _get_default_device():
     # Use cuda enabled gpu if it's available.
     if torch.cuda.is_available():
         device = "cuda"
@@ -145,6 +142,36 @@
     return device
 
 
+def _get_device(device=None):
+    if device is None or device == "auto":
+        device = _get_default_device()
+    else:
+        if device.lower() == "cuda":
+            if not torch.cuda.is_available():
+                raise RuntimeError("PyTorch CUDA backend is not available.")
+        elif device.lower() == "mps":
+            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
+                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
+        elif device.lower() == "cpu":
+            pass  # cpu is always available
+        else:
+            raise RuntimeError(f"Unsupported device: {device}\n"
+                               "Please choose from 'cpu', 'cuda', or 'mps'.")
+    return device

[Codecov warning: added line #L158 was not covered by tests]

+
+
+def _available_devices():
+    available_devices = []
+    for i in ["cuda", "mps", "cpu"]:
+        try:
+            device = _get_device(i)
+        except RuntimeError:
+            pass
+        else:
+            available_devices.append(device)
+    return available_devices
 
 
 def get_sam_model(
     model_type: str = _DEFAULT_MODEL,
     device: Optional[str] = None,
@@ -258,6 +285,7 @@
     checkpoint_path: Union[str, os.PathLike],
     model_type: str,
     save_path: Union[str, os.PathLike],
+    device: str = None,
 ) -> None:
     """Export a finetuned segment anything model to the standard model format.
 
@@ -269,7 +297,7 @@
         save_path: Where to save the exported model.
     """
     _, state = get_custom_sam_model(
-        checkpoint_path, model_type=model_type, return_state=True, device=torch.device("cpu"),
+        checkpoint_path, model_type=model_type, return_state=True, device=device,
     )
     model_state = state["model_state"]
     prefix = "sam."
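For illustration, the behavior one would expect from the new device helpers on a machine with a CUDA GPU and no MPS backend (an assumption; results are platform-dependent):

from micro_sam.util import _get_device, _available_devices

print(_get_device())         # "cuda" -- falls back to _get_default_device()
print(_get_device("auto"))   # "cuda"
print(_get_device("cpu"))    # "cpu" -- always available
print(_available_devices())  # ["cuda", "cpu"] -- "mps" raised RuntimeError and was skipped
# _get_device("mps") would raise RuntimeError on this machine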
2 changes: 0 additions & 2 deletions micro_sam/visualization.py
@@ -3,8 +3,6 @@
 """
 from typing import Tuple
 
-from typing import Tuple
-
 import numpy as np
 
 from elf.segmentation.embeddings import embedding_pca
46 changes: 46 additions & 0 deletions test/test_sam_annotator/test_widgets.py
@@ -0,0 +1,46 @@
import json
import os

from mobile_sam.predictor import SamPredictor as MobileSamPredictor
from segment_anything.predictor import SamPredictor
import torch
import zarr

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.sam_annotator._widgets import embedding_widget, Model
from micro_sam.util import _compute_data_signature


# make_napari_viewer is a pytest fixture that returns a napari viewer object
# you don't need to import it, as long as napari is installed
# in your testing environment.
# tmp_path is a regular pytest fixture.
def test_embedding_widget(make_napari_viewer, tmp_path):
    """Test embedding widget for micro-sam napari plugin."""
    # setup
    viewer = make_napari_viewer()
    layer = viewer.open_sample('napari', 'camera')[0]
    my_widget = embedding_widget()
    # run image embedding widget
    worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path)
    worker.await_workers()  # blocks until the thread worker has finished computing the embedding
    # Check in-memory state - predictor
    assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor))
    # Check in-memory state - image embeddings
    assert AnnotatorState().image_embeddings is not None
    assert 'features' in AnnotatorState().image_embeddings.keys()
    assert 'input_size' in AnnotatorState().image_embeddings.keys()
    assert 'original_size' in AnnotatorState().image_embeddings.keys()
    assert isinstance(AnnotatorState().image_embeddings["features"], torch.Tensor)
    assert AnnotatorState().image_embeddings["original_size"] == layer.data.shape
    # Check saved embedding results are what we expect to have
    temp_path_files = os.listdir(tmp_path)
    temp_path_files.sort()
    assert temp_path_files == ['.zattrs', '.zgroup', 'features']
    with open(os.path.join(tmp_path, ".zattrs")) as f:
        content = f.read()
        zarr_dict = json.loads(content)
    assert zarr_dict.get("original_size") == list(layer.data.shape)
    assert zarr_dict.get("data_signature") == _compute_data_signature(layer.data)
    assert zarr.open(os.path.join(tmp_path, "features")).shape == (1, 256, 64, 64)
    viewer.close()  # must close the viewer at the end of tests
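For orientation, a sketch of inspecting the on-disk store that the test checks (the shapes assume the 512x512 'camera' sample and the vit_t model, as asserted above):

import zarr

f = zarr.open(str(tmp_path), mode="r")  # tmp_path as in the test above
print(dict(f.attrs))        # contains 'data_signature', 'input_size', 'original_size'
print(f["features"].shape)  # (1, 256, 64, 64)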
5 changes: 3 additions & 2 deletions test/test_training.py
@@ -110,13 +110,14 @@ def _train_model(self, model_type, device):
         )
         trainer.fit(epochs=1)
 
-    def _export_model(self, checkpoint_path, export_path, model_type):
+    def _export_model(self, checkpoint_path, export_path, model_type, device):
         from micro_sam.util import export_custom_sam_model
 
         export_custom_sam_model(
             checkpoint_path=checkpoint_path,
             model_type=model_type,
             save_path=export_path,
+            device=device,
         )
 
     def _run_inference_and_check_results(
@@ -152,7 +153,7 @@ def test_training(self):
 
         # Export the model.
         export_path = os.path.join(self.tmp_folder, "exported_model.pth")
-        self._export_model(checkpoint_path, export_path, model_type)
+        self._export_model(checkpoint_path, export_path, model_type, device)
         self.assertTrue(os.path.exists(export_path))
 
         # Check the model with inference with a single point prompt.