Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Embedding widget for napari plugin #235

Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
21829d5
Start on embedding widget
GenevieveBuckley Oct 17, 2023
ac0412e
Add missing typing import in visualization.py file
GenevieveBuckley Oct 17, 2023
d8eb08b
Fix npe2 napari manifest validation
GenevieveBuckley Oct 17, 2023
4a2ed76
Graceful error handling for torch device backend selection
GenevieveBuckley Oct 19, 2023
38734f0
Embedding widget for napari plugin
GenevieveBuckley Oct 19, 2023
9bbf661
Merge branch 'dev' into embedding-widget
GenevieveBuckley Oct 19, 2023
f0b3f49
Fix _available_devices utility function
GenevieveBuckley Oct 19, 2023
c18af45
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Oct 19, 2023
45c1a57
More clear variable names
GenevieveBuckley Oct 19, 2023
b9c37f0
Add GUI test for embedding widget
GenevieveBuckley Oct 19, 2023
3405039
Thread worker function doesn't actually return object the same way si…
GenevieveBuckley Oct 19, 2023
0314ef2
Fix rgb ndim calculation
GenevieveBuckley Oct 19, 2023
522b27b
Move location where global IMAGE_EMBEDDINGS is defined
GenevieveBuckley Oct 19, 2023
b39d5c9
Order of output files does not matter for test
GenevieveBuckley Oct 23, 2023
174a9c6
Fix ndim for rgb images in embedding_widget
GenevieveBuckley Oct 23, 2023
147797c
Let's be careful since now the prgress bar is an unnamed argument to …
GenevieveBuckley Oct 23, 2023
9b1730f
Match progress bar with thread worker example from napari/examples
GenevieveBuckley Oct 23, 2023
5c573b9
Sanitize user string input to _get_device()
GenevieveBuckley Oct 24, 2023
464d7c3
Merge branch 'dev' into embedding-widget-singleton
GenevieveBuckley Nov 3, 2023
ae9b611
Workaround for issue 246
GenevieveBuckley Nov 3, 2023
c1fe954
Image embedding widget now uses singleton AnnotatorState
GenevieveBuckley Nov 3, 2023
8e75871
Embedding widget, ensure save directory exists and is empty
GenevieveBuckley Nov 3, 2023
1de747e
Don't set a device for custom checkpoint export
constantinpape Nov 4, 2023
1c68c03
More consise code with os.makedirs exist_ok
GenevieveBuckley Nov 5, 2023
856da60
Merge branch 'embedding-widget' of github.com:GenevieveBuckley/micro-…
GenevieveBuckley Nov 5, 2023
bfa8c57
Upgrade invalid embeddings path from user warning to runtime error (s…
GenevieveBuckley Nov 7, 2023
51e96bf
Move all computation into thread worker, allow previously computed em…
GenevieveBuckley Nov 7, 2023
c0c83a4
Add reset_state method to clear all attributes held in state
GenevieveBuckley Nov 7, 2023
a4ad1dc
Remove data_signature attribute from AnnotatorState attributes
GenevieveBuckley Nov 7, 2023
ca3c77e
Embedding widget, image_shape in annotator state
GenevieveBuckley Nov 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions micro_sam/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ contributions:
- id: micro-sam.sample_data_segmentation
python_name: micro_sam.sample_data:sample_data_segmentation
title: Load segmentation sample data from micro-sam plugin
- id: micro-sam.embedding_widget
python_name: micro_sam.sam_annotator._widgets:embedding_widget
title: Embedding widget
sample_data:
- command: micro-sam.sample_data_image_series
display_name: Image series example data
Expand All @@ -45,3 +48,6 @@ contributions:
- command: micro-sam.sample_data_segmentation
display_name: Segmentation sample dataset
key: micro-sam-segmentation
widgets:
- command: micro-sam.embedding_widget
display_name: Embedding widget
65 changes: 65 additions & 0 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from magicgui import magic_factory
from magicgui.tqdm import tqdm
from napari.qt.threading import thread_worker

from micro_sam.util import (
ImageEmbeddings,
get_sam_model,
precompute_image_embeddings,
_MODEL_URLS,
_DEFAULT_MODEL,
_available_devices,
)

if TYPE_CHECKING:
import napari

Model = Enum("Model", _MODEL_URLS)
available_devices_list = ["auto"] + _available_devices()


@magic_factory(call_button="Compute image embeddings",
device = {"choices": available_devices_list},
save_path={"mode": "d"}, # choose a directory
)
def embedding_widget(
image: "napari.layers.Image",
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved
optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
"""Image embedding widget."""
# for access to the predictor and the image embeddings in the widgets
global PREDICTOR, IMAGE_EMBEDDINGS
# Initialize the model
PREDICTOR = get_sam_model(device=device, model_type=model.name,
checkpoint_path=optional_custom_weights)

# Get image dimensions
if not image.rgb:
ndim = image.data.ndim
else:
# assumes RGB channels are the last dimension
ndim = image.data.ndim
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved

# Compute the embeddings for the image data
with tqdm() as pbar:
@thread_worker(connect={"finished": lambda: pbar._get_progressbar().hide()})
def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None):
if save_path is not None:
save_path = str(save_path)
IMAGE_EMBEDDINGS = precompute_image_embeddings(
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved
predictor = PREDICTOR,
input_ = image_data,
save_path = save_path,
ndim=ndim,
)

worker = _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim)

return worker
35 changes: 31 additions & 4 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,7 @@ def _get_checkpoint(model_type, checkpoint_path=None):
return checkpoint_path


def _get_device(device):
if device is not None:
return device

def _get_default_device():
# Use cuda enabled gpu if it's available.
if torch.cuda.is_available():
device = "cuda"
Expand All @@ -145,6 +142,36 @@ def _get_device(device):
return device


def _get_device(device=None):
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved
if device is None or device == "auto":
device = _get_default_device()
else:
if device == "cuda":
if not torch.cuda.is_available():
raise RuntimeError("PyTorch CUDA backend is not available.")
elif device == "mps":
if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
elif device == "cpu":
pass # cpu is always available
else:
raise RuntimeError(f"Unsupported device: {device}\n"
"Please choose from 'cpu', 'cuda', or 'mps'.")
return device


def _available_devices():
available_devices = []
for i in ["cuda", "mps", "cpu"]:
try:
device = _get_device(i)
except RuntimeError:
pass
else:
available_devices.append(device)
return available_devices


def get_sam_model(
device: Optional[str] = None,
model_type: str = _DEFAULT_MODEL,
Expand Down
2 changes: 0 additions & 2 deletions micro_sam/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
"""
from typing import Tuple

from typing import Tuple
GenevieveBuckley marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np

from elf.segmentation.embeddings import embedding_pca
Expand Down
31 changes: 31 additions & 0 deletions test/test_sam_annotator/test_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import json
import os

import zarr

from micro_sam.sam_annotator._widgets import embedding_widget, Model
from micro_sam.util import _compute_data_signature


# make_napari_viewer is a pytest fixture that returns a napari viewer object
# you don't need to import it, as long as napari is installed
# in your testing environment.
# tmp_path is a regular pytest fixture.
def test_embedding_widget(make_napari_viewer, tmp_path):
constantinpape marked this conversation as resolved.
Show resolved Hide resolved
"""Test embedding widget for micro-sam napari plugin."""
# setup
viewer = make_napari_viewer()
layer = viewer.open_sample('napari', 'camera')[0]
my_widget = embedding_widget()
# run image embedding widget
worker = my_widget(layer, model=Model.vit_t, device="cpu", save_path=tmp_path)
worker.await_workers() # blocks until thread worker is finished the embedding
# Open embedding results and check they are as expected
assert os.listdir(tmp_path) == ['.zattrs', '.zgroup', 'features']
with open(os.path.join(tmp_path, ".zattrs")) as f:
content = f.read()
zarr_dict = json.loads(content)
assert zarr_dict.get("original_size") == list(layer.data.shape)
assert zarr_dict.get("data_signature") == _compute_data_signature(layer.data)
assert zarr.open(os.path.join(tmp_path, "features")).shape == (1, 256, 64, 64)
viewer.close() # must close the viewer at the end of tests
Loading