From 21829d52ac3570f84c24f7805a3bee9f4c94c923 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:10:06 +1100 Subject: [PATCH 01/26] Start on embedding widget --- micro_sam/napari.yaml | 7 +++++ micro_sam/sam_annotator/widgets.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 micro_sam/sam_annotator/widgets.py diff --git a/micro_sam/napari.yaml b/micro_sam/napari.yaml index f6cdc25b..bf5e1c83 100644 --- a/micro_sam/napari.yaml +++ b/micro_sam/napari.yaml @@ -23,6 +23,9 @@ contributions: - id: micro-sam.sample_data_segmentation python_name: micro_sam.sample_data:sample_data_segmentation title: Load segmentation sample data from micro-sam plugin + - id: micro-sam.embedding_widget + python_name: micro_sam.sam_annotator.widgets.embedding_widget + title: Embedding widget sample_data: - command: micro-sam.sample_data_image_series display_name: Image series example data @@ -45,3 +48,7 @@ contributions: - command: micro-sam.sample_data_segmentation display_name: Segmentation sample dataset key: micro-sam-segmentation + widgets: + - command: micro-sam.embedding_widget + display_name: Embedding widget + key: micro-sam-embedding diff --git a/micro_sam/sam_annotator/widgets.py b/micro_sam/sam_annotator/widgets.py new file mode 100644 index 00000000..4180a1d1 --- /dev/null +++ b/micro_sam/sam_annotator/widgets.py @@ -0,0 +1,45 @@ +from enum import Enum +from pathlib import Path + +from magicgui import magicgui +from napari.types import ImageData +from superqt.utils import thread_worker + +from magicgui import magicgui +from magicgui.tqdm import tqdm + +from micro_sam.util import ( + ImageEmbeddings, + get_sam_model, + precompute_image_embeddings, + _MODEL_URLS, + _DEFAULT_MODEL, +) + + +Model = Enum("Model", _MODEL_URLS) + + +@magicgui(call_button="Compute image embedding") +def embedding_widget( + image: ImageData, + model: Model = Model.__getitem__(_DEFAULT_MODEL), + save_path: Path | str | None = None, + ) -> ImageEmbeddings: + """Image embedding widget.""" + PREDICTOR = get_sam_model(model_type=model.name) + with tqdm() as pbar: + @thread_worker(connect={"finished": lambda: pbar.progressbar.hide()}) + def _compute_image_embedding(image, model, save_path): + image_embeddings = precompute_image_embeddings( + predictor = PREDICTOR, + input_ = image, + save_path = str(save_path), + ) + return image_embeddings + + print("Computing image embedding...") + image_embeddings = _compute_image_embedding(image, model, save_path) + print("Finished image embedding computation.") + + return image_embeddings From ac0412efcb48a250c7ee03449a0f3ca730889a8c Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:16:12 +1100 Subject: [PATCH 02/26] Add missing typing import in visualization.py file --- micro_sam/visualization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/micro_sam/visualization.py b/micro_sam/visualization.py index 0426e193..c931985f 100644 --- a/micro_sam/visualization.py +++ b/micro_sam/visualization.py @@ -1,6 +1,7 @@ """ Functionality for visualizing image embeddings. """ +from typing import Tuple import numpy as np From d8eb08b73e485fc64e0691dbcce1d44602da1734 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:16:29 +1100 Subject: [PATCH 03/26] Fix npe2 napari manifest validation --- micro_sam/napari.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/micro_sam/napari.yaml b/micro_sam/napari.yaml index bf5e1c83..a5557fa1 100644 --- a/micro_sam/napari.yaml +++ b/micro_sam/napari.yaml @@ -24,7 +24,7 @@ contributions: python_name: micro_sam.sample_data:sample_data_segmentation title: Load segmentation sample data from micro-sam plugin - id: micro-sam.embedding_widget - python_name: micro_sam.sam_annotator.widgets.embedding_widget + python_name: micro_sam.sam_annotator.widgets:embedding_widget title: Embedding widget sample_data: - command: micro-sam.sample_data_image_series @@ -51,4 +51,3 @@ contributions: widgets: - command: micro-sam.embedding_widget display_name: Embedding widget - key: micro-sam-embedding From 4a2ed76e5c392b3611ec2706496c2e18d0ac96a4 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:42:50 +1100 Subject: [PATCH 04/26] Graceful error handling for torch device backend selection --- .../sam_annotator/{widgets.py => _widgets.py} | 0 micro_sam/util.py | 33 ++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) rename micro_sam/sam_annotator/{widgets.py => _widgets.py} (100%) diff --git a/micro_sam/sam_annotator/widgets.py b/micro_sam/sam_annotator/_widgets.py similarity index 100% rename from micro_sam/sam_annotator/widgets.py rename to micro_sam/sam_annotator/_widgets.py diff --git a/micro_sam/util.py b/micro_sam/util.py index 4eebd3f1..15125b57 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -126,10 +126,7 @@ def _get_checkpoint(model_type, checkpoint_path=None): return checkpoint_path -def _get_device(device): - if device is not None: - return device - +def _get_default_device(): # Use cuda enabled gpu if it's available. if torch.cuda.is_available(): device = "cuda" @@ -144,6 +141,34 @@ def _get_device(device): return device +def _get_device(device=None): + if device is None or device == "auto": + device = _get_default_device() + else: + if device == "cuda": + assert torch.cuda.is_available() + elif device == "mps": + assert torch.backends.mps.is_available() and torch.backends.mps.is_built() + elif device == "cpu": + pass # cpu is always available + else: + raise RuntimeError(f"Unsupported device: {device}\n" + "Please choose from 'cpu', 'cuda', or 'mps'.") + return device + + +def _available_devices(): + available_devices = [None] + for i in ["cuda", "mps", "cpu"]: + try: + device = _get_device(i) + except RuntimeError: + pass + else: + available_devices.append(device) + return available_devices + + def get_sam_model( device: Optional[str] = None, model_type: str = _DEFAULT_MODEL, From 38734f063422aea51f37028241a5fdcd2f44f3e2 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:43:34 +1100 Subject: [PATCH 05/26] Embedding widget for napari plugin --- micro_sam/napari.yaml | 2 +- micro_sam/sam_annotator/_widgets.py | 47 ++++++++++++++++++----------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/micro_sam/napari.yaml b/micro_sam/napari.yaml index a5557fa1..e8bd09c7 100644 --- a/micro_sam/napari.yaml +++ b/micro_sam/napari.yaml @@ -24,7 +24,7 @@ contributions: python_name: micro_sam.sample_data:sample_data_segmentation title: Load segmentation sample data from micro-sam plugin - id: micro-sam.embedding_widget - python_name: micro_sam.sam_annotator.widgets:embedding_widget + python_name: micro_sam.sam_annotator._widgets:embedding_widget title: Embedding widget sample_data: - command: micro-sam.sample_data_image_series diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4180a1d1..447a757c 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,12 +1,10 @@ from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING, Optional -from magicgui import magicgui -from napari.types import ImageData -from superqt.utils import thread_worker - -from magicgui import magicgui +from magicgui import magic_factory from magicgui.tqdm import tqdm +from superqt.utils import thread_worker from micro_sam.util import ( ImageEmbeddings, @@ -16,30 +14,43 @@ _DEFAULT_MODEL, ) +if TYPE_CHECKING: + import napari Model = Enum("Model", _MODEL_URLS) -@magicgui(call_button="Compute image embedding") +@magic_factory(call_button="Compute image embeddings", + device = {"choices": ["auto", 'cuda', 'mps', 'cpu']}) def embedding_widget( - image: ImageData, - model: Model = Model.__getitem__(_DEFAULT_MODEL), - save_path: Path | str | None = None, - ) -> ImageEmbeddings: + image: "napari.layers.Image", + model: Model = Model.__getitem__(_DEFAULT_MODEL), + device = "auto", + save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) + custom_model: Optional[str] = None, # A filepath or URL to custom model weights. +) -> ImageEmbeddings: """Image embedding widget.""" + # for access to the predictor and the image embeddings in the widgets + global PREDICTOR, IMAGE_EMBEDDINGS + # Initialize the model PREDICTOR = get_sam_model(model_type=model.name) + + # Get image dimensions + if not image.rgb: + ndim = image.data.ndim + else: + # assumes RGB channels are the last dimension + ndim = image.data.ndim + + # Compute the embeddings for the image data with tqdm() as pbar: @thread_worker(connect={"finished": lambda: pbar.progressbar.hide()}) - def _compute_image_embedding(image, model, save_path): - image_embeddings = precompute_image_embeddings( + def _compute_image_embedding(PREDICTOR, image, save_path, ndim=None): + IMAGE_EMBEDDINGS = precompute_image_embeddings( predictor = PREDICTOR, input_ = image, save_path = str(save_path), + ndim=ndim, ) - return image_embeddings - - print("Computing image embedding...") - image_embeddings = _compute_image_embedding(image, model, save_path) - print("Finished image embedding computation.") - return image_embeddings + _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) From f0b3f4908a2081c162937b86e88045a920466168 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:13:27 +1100 Subject: [PATCH 06/26] Fix _available_devices utility function --- micro_sam/sam_annotator/_widgets.py | 4 +++- micro_sam/util.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 447a757c..f083f24d 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -12,16 +12,18 @@ precompute_image_embeddings, _MODEL_URLS, _DEFAULT_MODEL, + _available_devices, ) if TYPE_CHECKING: import napari Model = Enum("Model", _MODEL_URLS) +available_devices_list = ["auto"] + _available_devices() @magic_factory(call_button="Compute image embeddings", - device = {"choices": ["auto", 'cuda', 'mps', 'cpu']}) + device = {"choices": available_devices_list}) def embedding_widget( image: "napari.layers.Image", model: Model = Model.__getitem__(_DEFAULT_MODEL), diff --git a/micro_sam/util.py b/micro_sam/util.py index 15125b57..fb7e83ca 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -146,9 +146,11 @@ def _get_device(device=None): device = _get_default_device() else: if device == "cuda": - assert torch.cuda.is_available() + if not torch.cuda.is_available(): + raise RuntimeError("PyTorch CUDA backend is not available.") elif device == "mps": - assert torch.backends.mps.is_available() and torch.backends.mps.is_built() + if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): + raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") elif device == "cpu": pass # cpu is always available else: @@ -158,7 +160,7 @@ def _get_device(device=None): def _available_devices(): - available_devices = [None] + available_devices = [] for i in ["cuda", "mps", "cpu"]: try: device = _get_device(i) @@ -166,7 +168,7 @@ def _available_devices(): pass else: available_devices.append(device) - return available_devices + return available_devices def get_sam_model( From 45c1a57c7428fd601a1d1f33a5c9a73aead31162 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:25:03 +1100 Subject: [PATCH 07/26] More clear variable names --- micro_sam/sam_annotator/_widgets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index f083f24d..7eaf6288 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -47,10 +47,10 @@ def embedding_widget( # Compute the embeddings for the image data with tqdm() as pbar: @thread_worker(connect={"finished": lambda: pbar.progressbar.hide()}) - def _compute_image_embedding(PREDICTOR, image, save_path, ndim=None): + def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): IMAGE_EMBEDDINGS = precompute_image_embeddings( predictor = PREDICTOR, - input_ = image, + input_ = image_data, save_path = str(save_path), ndim=ndim, ) From b9c37f0d08c647a9d00e8e4d2d4cb07c38aadb99 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:13:46 +1100 Subject: [PATCH 08/26] Add GUI test for embedding widget --- micro_sam/sam_annotator/_widgets.py | 20 +++++++++++----- test/test_sam_annotator/test_widgets.py | 31 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 test/test_sam_annotator/test_widgets.py diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 7eaf6288..fcf83f7f 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -23,19 +23,22 @@ @magic_factory(call_button="Compute image embeddings", - device = {"choices": available_devices_list}) + device = {"choices": available_devices_list}, + save_path={"mode": "d"}, # choose a directory + ) def embedding_widget( image: "napari.layers.Image", model: Model = Model.__getitem__(_DEFAULT_MODEL), device = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) - custom_model: Optional[str] = None, # A filepath or URL to custom model weights. + optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. ) -> ImageEmbeddings: """Image embedding widget.""" # for access to the predictor and the image embeddings in the widgets global PREDICTOR, IMAGE_EMBEDDINGS # Initialize the model - PREDICTOR = get_sam_model(model_type=model.name) + PREDICTOR = get_sam_model(device=device, model_type=model.name, + checkpoint_path=optional_custom_weights) # Get image dimensions if not image.rgb: @@ -46,13 +49,18 @@ def embedding_widget( # Compute the embeddings for the image data with tqdm() as pbar: - @thread_worker(connect={"finished": lambda: pbar.progressbar.hide()}) + @thread_worker(connect={"finished": lambda: pbar._get_progressbar().hide()}) def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): + if save_path is not None: + save_path = str(save_path) IMAGE_EMBEDDINGS = precompute_image_embeddings( predictor = PREDICTOR, input_ = image_data, - save_path = str(save_path), + save_path = save_path, ndim=ndim, ) + return IMAGE_EMBEDDINGS - _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) + result = _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) + + return result diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py new file mode 100644 index 00000000..9df588be --- /dev/null +++ b/test/test_sam_annotator/test_widgets.py @@ -0,0 +1,31 @@ +import json +import os + +import zarr + +from micro_sam.sam_annotator._widgets import embedding_widget, Model +from micro_sam.util import _compute_data_signature + + +# make_napari_viewer is a pytest fixture that returns a napari viewer object +# you don't need to import it, as long as napari is installed +# in your testing environment. +# tmp_path is a regular pytest fixture. +def test_embedding_widget(make_napari_viewer, tmp_path): + """Test embedding widget for micro-sam napari plugin.""" + # setup + viewer = make_napari_viewer() + layer = viewer.open_sample('napari', 'camera')[0] + my_widget = embedding_widget() + # run image embedding widget + worker = my_widget(layer, model=Model.vit_t, device="cpu", save_path=tmp_path) + worker.await_workers() # blocks until thread worker is finished the embedding + # Open embedding results and check they are as expected + assert os.listdir(tmp_path) == ['.zattrs', '.zgroup', 'features'] + with open(os.path.join(tmp_path, ".zattrs")) as f: + content = f.read() + zarr_dict = json.loads(content) + assert zarr_dict.get("original_size") == list(layer.data.shape) + assert zarr_dict.get("data_signature") == _compute_data_signature(layer.data) + assert zarr.open(os.path.join(tmp_path, "features")).shape == (1, 256, 64, 64) + viewer.close() # must close the viewer at the end of tests From 3405039eada235bbe16faff12a443f994c56ec95 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:20:56 +1100 Subject: [PATCH 09/26] Thread worker function doesn't actually return object the same way single-threaded code works --- micro_sam/sam_annotator/_widgets.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index fcf83f7f..c27e301d 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -4,7 +4,7 @@ from magicgui import magic_factory from magicgui.tqdm import tqdm -from superqt.utils import thread_worker +from napari.qt.threading import thread_worker from micro_sam.util import ( ImageEmbeddings, @@ -59,8 +59,7 @@ def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): save_path = save_path, ndim=ndim, ) - return IMAGE_EMBEDDINGS - result = _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) + worker = _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) - return result + return worker From 0314ef264eb037ec400f3b5301427a1a7901b92c Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Fri, 20 Oct 2023 10:26:25 +1100 Subject: [PATCH 10/26] Fix rgb ndim calculation --- micro_sam/sam_annotator/_widgets.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index c27e301d..2cd160cb 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -40,12 +40,7 @@ def embedding_widget( PREDICTOR = get_sam_model(device=device, model_type=model.name, checkpoint_path=optional_custom_weights) - # Get image dimensions - if not image.rgb: - ndim = image.data.ndim - else: - # assumes RGB channels are the last dimension - ndim = image.data.ndim + ndim = image.data.ndim # Get image dimensions # Compute the embeddings for the image data with tqdm() as pbar: From 522b27b3302c01868885814e4c050cc8e4d10918 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Fri, 20 Oct 2023 10:31:48 +1100 Subject: [PATCH 11/26] Move location where global IMAGE_EMBEDDINGS is defined --- micro_sam/sam_annotator/_widgets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 2cd160cb..7f91ce78 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -35,7 +35,7 @@ def embedding_widget( ) -> ImageEmbeddings: """Image embedding widget.""" # for access to the predictor and the image embeddings in the widgets - global PREDICTOR, IMAGE_EMBEDDINGS + global PREDICTOR # Initialize the model PREDICTOR = get_sam_model(device=device, model_type=model.name, checkpoint_path=optional_custom_weights) @@ -48,6 +48,7 @@ def embedding_widget( def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): if save_path is not None: save_path = str(save_path) + global IMAGE_EMBEDDINGS IMAGE_EMBEDDINGS = precompute_image_embeddings( predictor = PREDICTOR, input_ = image_data, From b39d5c9bc77a2a2ca74e529a8263a5b1292ec877 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Mon, 23 Oct 2023 18:40:45 +1100 Subject: [PATCH 12/26] Order of output files does not matter for test --- test/test_sam_annotator/test_widgets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index 9df588be..188f5d0a 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -21,7 +21,9 @@ def test_embedding_widget(make_napari_viewer, tmp_path): worker = my_widget(layer, model=Model.vit_t, device="cpu", save_path=tmp_path) worker.await_workers() # blocks until thread worker is finished the embedding # Open embedding results and check they are as expected - assert os.listdir(tmp_path) == ['.zattrs', '.zgroup', 'features'] + temp_path_files = os.listdir(tmp_path) + temp_path_files.sort() + assert temp_path_files == ['.zattrs', '.zgroup', 'features'] with open(os.path.join(tmp_path, ".zattrs")) as f: content = f.read() zarr_dict = json.loads(content) From 174a9c6d2b902eea1f47648e0880838fa4ff0932 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:02:43 +1100 Subject: [PATCH 13/26] Fix ndim for rgb images in embedding_widget --- micro_sam/sam_annotator/_widgets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 7f91ce78..8d0bfb0f 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -39,8 +39,10 @@ def embedding_widget( # Initialize the model PREDICTOR = get_sam_model(device=device, model_type=model.name, checkpoint_path=optional_custom_weights) - - ndim = image.data.ndim # Get image dimensions + # Get image dimensions + ndim = image.data.ndim + if image.rgb: + ndim -= 1 # Compute the embeddings for the image data with tqdm() as pbar: From 147797cf30d7416afdb3079fceb08423d0d513de Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:22:43 +1100 Subject: [PATCH 14/26] Let's be careful since now the prgress bar is an unnamed argument to this function --- test/test_sam_annotator/test_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index 188f5d0a..ff7c517d 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -18,7 +18,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path): layer = viewer.open_sample('napari', 'camera')[0] my_widget = embedding_widget() # run image embedding widget - worker = my_widget(layer, model=Model.vit_t, device="cpu", save_path=tmp_path) + worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path) worker.await_workers() # blocks until thread worker is finished the embedding # Open embedding results and check they are as expected temp_path_files = os.listdir(tmp_path) From 9b1730f4967dae3006d066bc803f59640e6ff825 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Mon, 23 Oct 2023 19:23:56 +1100 Subject: [PATCH 15/26] Match progress bar with thread worker example from napari/examples --- micro_sam/sam_annotator/_widgets.py | 46 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 8d0bfb0f..7d1cf970 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -2,8 +2,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from magicgui import magic_factory -from magicgui.tqdm import tqdm +from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker from micro_sam.util import ( @@ -22,11 +21,14 @@ available_devices_list = ["auto"] + _available_devices() -@magic_factory(call_button="Compute image embeddings", - device = {"choices": available_devices_list}, - save_path={"mode": "d"}, # choose a directory - ) +@magic_factory( + pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'}, + call_button="Compute image embeddings", + device = {"choices": available_devices_list}, + save_path={"mode": "d"}, # choose a directory + ) def embedding_widget( + pbar: widgets.ProgressBar, image: "napari.layers.Image", model: Model = Model.__getitem__(_DEFAULT_MODEL), device = "auto", @@ -44,20 +46,18 @@ def embedding_widget( if image.rgb: ndim -= 1 - # Compute the embeddings for the image data - with tqdm() as pbar: - @thread_worker(connect={"finished": lambda: pbar._get_progressbar().hide()}) - def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): - if save_path is not None: - save_path = str(save_path) - global IMAGE_EMBEDDINGS - IMAGE_EMBEDDINGS = precompute_image_embeddings( - predictor = PREDICTOR, - input_ = image_data, - save_path = save_path, - ndim=ndim, - ) - - worker = _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) - - return worker + # Compute the image embeddings + @thread_worker(connect={'started': pbar.show, 'returned': pbar.hide}) + def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): + if save_path is not None: + save_path = str(save_path) + global IMAGE_EMBEDDINGS + IMAGE_EMBEDDINGS = precompute_image_embeddings( + predictor = PREDICTOR, + input_ = image_data, + save_path = save_path, + ndim=ndim, + ) + return IMAGE_EMBEDDINGS # returns napari._qt.qthreading.FunctionWorker + + return _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) From 5c573b96f8021912184386f48f175c374f15937f Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 25 Oct 2023 10:15:17 +1100 Subject: [PATCH 16/26] Sanitize user string input to _get_device() --- micro_sam/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index 3f8abc71..3c88005a 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -146,13 +146,13 @@ def _get_device(device=None): if device is None or device == "auto": device = _get_default_device() else: - if device == "cuda": + if device.lower() == "cuda": if not torch.cuda.is_available(): raise RuntimeError("PyTorch CUDA backend is not available.") - elif device == "mps": + elif device.lower() == "mps": if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") - elif device == "cpu": + elif device.lower() == "cpu": pass # cpu is always available else: raise RuntimeError(f"Unsupported device: {device}\n" From ae9b6110fcdff87b333f6417c44494e9b3970613 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:16:17 +1100 Subject: [PATCH 17/26] Workaround for issue 246 --- micro_sam/util.py | 3 ++- test/test_training.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index af755d99..2e3bbffd 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -285,6 +285,7 @@ def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], + device: str = None, ) -> None: """Export a finetuned segment anything model to the standard model format. @@ -296,7 +297,7 @@ def export_custom_sam_model( save_path: Where to save the exported model. """ _, state = get_custom_sam_model( - checkpoint_path, model_type=model_type, return_state=True, device=torch.device("cpu"), + checkpoint_path, model_type=model_type, return_state=True, device=device, ) model_state = state["model_state"] prefix = "sam." diff --git a/test/test_training.py b/test/test_training.py index 60a59903..322c37a0 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -110,13 +110,14 @@ def _train_model(self, model_type, device): ) trainer.fit(epochs=1) - def _export_model(self, checkpoint_path, export_path, model_type): + def _export_model(self, checkpoint_path, export_path, model_type, device): from micro_sam.util import export_custom_sam_model export_custom_sam_model( checkpoint_path=checkpoint_path, model_type=model_type, save_path=export_path, + device=device, ) def _run_inference_and_check_results( @@ -152,7 +153,7 @@ def test_training(self): # Export the model. export_path = os.path.join(self.tmp_folder, "exported_model.pth") - self._export_model(checkpoint_path, export_path, model_type) + self._export_model(checkpoint_path, export_path, model_type, device) self.assertTrue(os.path.exists(export_path)) # Check the model with inference with a single point prompt. From c1fe9546d0e89178e3a179e1c3b0883a327ec9c4 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:03:58 +1100 Subject: [PATCH 18/26] Image embedding widget now uses singleton AnnotatorState --- micro_sam/sam_annotator/_widgets.py | 17 ++++++++--------- test/test_sam_annotator/test_widgets.py | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 7d1cf970..831a5096 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -5,6 +5,7 @@ from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker +from micro_sam.sam_annotator._state import AnnotatorState from micro_sam.util import ( ImageEmbeddings, get_sam_model, @@ -36,10 +37,9 @@ def embedding_widget( optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. ) -> ImageEmbeddings: """Image embedding widget.""" - # for access to the predictor and the image embeddings in the widgets - global PREDICTOR + state = AnnotatorState() # Initialize the model - PREDICTOR = get_sam_model(device=device, model_type=model.name, + state.predictor = get_sam_model(device=device, model_type=model.name, checkpoint_path=optional_custom_weights) # Get image dimensions ndim = image.data.ndim @@ -48,16 +48,15 @@ def embedding_widget( # Compute the image embeddings @thread_worker(connect={'started': pbar.show, 'returned': pbar.hide}) - def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None): + def _compute_image_embedding(state, image_data, save_path, ndim=None): if save_path is not None: save_path = str(save_path) - global IMAGE_EMBEDDINGS - IMAGE_EMBEDDINGS = precompute_image_embeddings( - predictor = PREDICTOR, + state.image_embeddings = precompute_image_embeddings( + predictor = state.predictor, input_ = image_data, save_path = save_path, ndim=ndim, ) - return IMAGE_EMBEDDINGS # returns napari._qt.qthreading.FunctionWorker + return state.image_embeddings # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim) + return _compute_image_embedding(state, image.data, save_path, ndim=ndim) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index ff7c517d..dd4adb6f 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -1,8 +1,12 @@ import json import os +from mobile_sam.predictor import SamPredictor as MobileSamPredictor +from segment_anything.predictor import SamPredictor +import torch import zarr +from micro_sam.sam_annotator._state import AnnotatorState from micro_sam.sam_annotator._widgets import embedding_widget, Model from micro_sam.util import _compute_data_signature @@ -20,7 +24,16 @@ def test_embedding_widget(make_napari_viewer, tmp_path): # run image embedding widget worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path) worker.await_workers() # blocks until thread worker is finished the embedding - # Open embedding results and check they are as expected + # Check in-memory state - predictor + assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor)) + # Check in-memory state - image embeddings + assert AnnotatorState().image_embeddings is not None + assert 'features' in AnnotatorState().image_embeddings.keys() + assert 'input_size' in AnnotatorState().image_embeddings.keys() + assert 'original_size' in AnnotatorState().image_embeddings.keys() + assert isinstance(AnnotatorState().image_embeddings["features"], torch.Tensor) + assert AnnotatorState().image_embeddings["original_size"] == layer.data.shape + # Check saved embedding results are what we expect to have temp_path_files = os.listdir(tmp_path) temp_path_files.sort() assert temp_path_files == ['.zattrs', '.zgroup', 'features'] From 8e75871b0935ecd8e72aebddef31e3f08e4f08e0 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:59:55 +1100 Subject: [PATCH 19/26] Embedding widget, ensure save directory exists and is empty --- micro_sam/sam_annotator/_widgets.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 831a5096..a7092b3c 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,4 +1,5 @@ from enum import Enum +import os from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -37,6 +38,19 @@ def embedding_widget( optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. ) -> ImageEmbeddings: """Image embedding widget.""" + # Make sure save directory exists and is an empty directory + if save_path is not None: + if not save_path.exists(): + os.makedirs(save_path) + if not save_path.is_dir(): + raise NotADirectoryError( + f"The user selected 'save_path' is not a direcotry: {save_path}" + ) + if len(os.listdir(save_path)) > 0: + raise RuntimeError( + f"The user selected 'save_path' is not empty: {save_path}" + ) + state = AnnotatorState() # Initialize the model state.predictor = get_sam_model(device=device, model_type=model.name, From 1de747e3c17a7cf1920fab435ba995b58153ed30 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 4 Nov 2023 10:46:28 +0100 Subject: [PATCH 20/26] Don't set a device for custom checkpoint export --- micro_sam/util.py | 3 +-- test/test_training.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index 2e3bbffd..b5c45eb9 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -285,7 +285,6 @@ def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], - device: str = None, ) -> None: """Export a finetuned segment anything model to the standard model format. @@ -297,7 +296,7 @@ def export_custom_sam_model( save_path: Where to save the exported model. """ _, state = get_custom_sam_model( - checkpoint_path, model_type=model_type, return_state=True, device=device, + checkpoint_path, model_type=model_type, return_state=True, device="cpu", ) model_state = state["model_state"] prefix = "sam." diff --git a/test/test_training.py b/test/test_training.py index 322c37a0..60a59903 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -110,14 +110,13 @@ def _train_model(self, model_type, device): ) trainer.fit(epochs=1) - def _export_model(self, checkpoint_path, export_path, model_type, device): + def _export_model(self, checkpoint_path, export_path, model_type): from micro_sam.util import export_custom_sam_model export_custom_sam_model( checkpoint_path=checkpoint_path, model_type=model_type, save_path=export_path, - device=device, ) def _run_inference_and_check_results( @@ -153,7 +152,7 @@ def test_training(self): # Export the model. export_path = os.path.join(self.tmp_folder, "exported_model.pth") - self._export_model(checkpoint_path, export_path, model_type, device) + self._export_model(checkpoint_path, export_path, model_type) self.assertTrue(os.path.exists(export_path)) # Check the model with inference with a single point prompt. From 1c68c0362fe36716b5e265e618ff59a536a19677 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Mon, 6 Nov 2023 10:45:49 +1100 Subject: [PATCH 21/26] More consise code with os.makedirs exist_ok --- micro_sam/sam_annotator/_widgets.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index a7092b3c..7201b7c3 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -40,8 +40,7 @@ def embedding_widget( """Image embedding widget.""" # Make sure save directory exists and is an empty directory if save_path is not None: - if not save_path.exists(): - os.makedirs(save_path) + os.makedirs(save_path, exist_ok=True) if not save_path.is_dir(): raise NotADirectoryError( f"The user selected 'save_path' is not a direcotry: {save_path}" From bfa8c57c10681a0ab98bb551877b8343ae1f2b97 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:16:37 +1100 Subject: [PATCH 22/26] Upgrade invalid embeddings path from user warning to runtime error (so napari users will see error popup) --- micro_sam/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index b5c45eb9..f6608f37 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -579,7 +579,7 @@ def precompute_image_embeddings( continue # check whether the key signature does not match or is not in the file if key not in f.attrs or f.attrs[key] != val: - warnings.warn( + raise RuntimeError( f"Embeddings file {save_path} is invalid due to unmatching {key}: " f"{f.attrs.get(key)} != {val}.Please recompute embeddings in a new file." ) From 51e96bfe5972d36998abddaffb99304272e7be1c Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:18:11 +1100 Subject: [PATCH 23/26] Move all computation into thread worker, allow previously computed embeddings to exist --- micro_sam/sam_annotator/_widgets.py | 53 +++++++++++++++++------------ 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 7201b7c3..c117555b 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -5,6 +5,8 @@ from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker +import zarr +from zarr.errors import PathNotFoundError from micro_sam.sam_annotator._state import AnnotatorState from micro_sam.util import ( @@ -14,6 +16,7 @@ _MODEL_URLS, _DEFAULT_MODEL, _available_devices, + _compute_data_signature, ) if TYPE_CHECKING: @@ -38,38 +41,44 @@ def embedding_widget( optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. ) -> ImageEmbeddings: """Image embedding widget.""" - # Make sure save directory exists and is an empty directory - if save_path is not None: - os.makedirs(save_path, exist_ok=True) - if not save_path.is_dir(): - raise NotADirectoryError( - f"The user selected 'save_path' is not a direcotry: {save_path}" - ) - if len(os.listdir(save_path)) > 0: - raise RuntimeError( - f"The user selected 'save_path' is not empty: {save_path}" - ) - state = AnnotatorState() - # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model.name, - checkpoint_path=optional_custom_weights) # Get image dimensions ndim = image.data.ndim if image.rgb: ndim -= 1 - # Compute the image embeddings - @thread_worker(connect={'started': pbar.show, 'returned': pbar.hide}) - def _compute_image_embedding(state, image_data, save_path, ndim=None): + @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) + def _compute_image_embedding(state, image_data, save_path, ndim=None, + device="auto", model=Model.__getitem__(_DEFAULT_MODEL), + optional_custom_weights=None): + # Make sure save directory exists and is an empty directory if save_path is not None: - save_path = str(save_path) + os.makedirs(save_path, exist_ok=True) + if not save_path.is_dir(): + raise NotADirectoryError( + f"The user selected 'save_path' is not a direcotry: {save_path}" + ) + if len(os.listdir(save_path)) > 0: + try: + zarr.open(save_path, "r") + except PathNotFoundError: + raise RuntimeError( + "The user selected 'save_path' is not a zarr array " + f"or empty directory: {save_path}" + ) + # Initialize the model + state.predictor = get_sam_model(device=device, model_type=model.name, + checkpoint_path=optional_custom_weights) + # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( predictor = state.predictor, input_ = image_data, - save_path = save_path, + save_path = str(save_path), ndim=ndim, ) - return state.image_embeddings # returns napari._qt.qthreading.FunctionWorker + data_signature = _compute_data_signature(image_data) + state.data_signature = data_signature + state.image_shape = image_data.shape + return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim) + return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) From c0c83a48f64b0277baa54d0ef1c219e4ff690af1 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:22:05 +1100 Subject: [PATCH 24/26] Add reset_state method to clear all attributes held in state --- micro_sam/sam_annotator/_state.py | 12 ++++++++++++ micro_sam/sam_annotator/_widgets.py | 1 + 2 files changed, 13 insertions(+) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 2642f2d1..f625be6c 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -28,6 +28,7 @@ class AnnotatorState(metaclass=Singleton): image_embeddings: Optional[ImageEmbeddings] = None predictor: Optional[SamPredictor] = None image_shape: Optional[Tuple[int, int]] = None + data_signature: Optional[str] = None # amg: needs to be initialized for the automatic segmentation functionality. # amg_state: for storing the instance segmentation state for the 3d segmentation tool. @@ -67,3 +68,14 @@ def initialized_for_tracking(self): f"Invalid AnnotatorState: {init_sum} / 2 parts of the state " "needed for tracking are initialized." ) + + def reset_state(self): + """Reset state, clear all attributes.""" + self.image_embeddings = None + self.predictor = None + self.image_shape = None + self.data_signature = None + self.amg = None + self.amg_state = None + self.current_track_id = None + self.lineage = None diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index c117555b..926fb228 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -42,6 +42,7 @@ def embedding_widget( ) -> ImageEmbeddings: """Image embedding widget.""" state = AnnotatorState() + state.reset_state() # Get image dimensions ndim = image.data.ndim if image.rgb: From a4ad1dc77aab3af417f852dfbe1525e9255c3a6b Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:30:53 +1100 Subject: [PATCH 25/26] Remove data_signature attribute from AnnotatorState attributes --- micro_sam/sam_annotator/_state.py | 2 -- micro_sam/sam_annotator/_widgets.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index f625be6c..9bb9cab5 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -28,7 +28,6 @@ class AnnotatorState(metaclass=Singleton): image_embeddings: Optional[ImageEmbeddings] = None predictor: Optional[SamPredictor] = None image_shape: Optional[Tuple[int, int]] = None - data_signature: Optional[str] = None # amg: needs to be initialized for the automatic segmentation functionality. # amg_state: for storing the instance segmentation state for the 3d segmentation tool. @@ -74,7 +73,6 @@ def reset_state(self): self.image_embeddings = None self.predictor = None self.image_shape = None - self.data_signature = None self.amg = None self.amg_state = None self.current_track_id = None diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 926fb228..241e4bc1 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -16,7 +16,6 @@ _MODEL_URLS, _DEFAULT_MODEL, _available_devices, - _compute_data_signature, ) if TYPE_CHECKING: @@ -77,8 +76,6 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, save_path = str(save_path), ndim=ndim, ) - data_signature = _compute_data_signature(image_data) - state.data_signature = data_signature state.image_shape = image_data.shape return state # returns napari._qt.qthreading.FunctionWorker From ca3c77e4a1418246a153d068f6514a5789ba6161 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 8 Nov 2023 10:40:31 +1100 Subject: [PATCH 26/26] Embedding widget, image_shape in annotator state --- micro_sam/sam_annotator/_widgets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 241e4bc1..887aae8a 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -43,9 +43,12 @@ def embedding_widget( state = AnnotatorState() state.reset_state() # Get image dimensions - ndim = image.data.ndim if image.rgb: - ndim -= 1 + ndim = image.data.ndim - 1 + state.image_shape = image.data.shape[:-1] + else: + ndim = image.data.ndim + state.image_shape = image.data.shape @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, @@ -76,7 +79,6 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, save_path = str(save_path), ndim=ndim, ) - state.image_shape = image_data.shape return state # returns napari._qt.qthreading.FunctionWorker return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)