Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP refactor state in singleton #3

Draft
wants to merge 3 commits into
base: embedding-widget
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions check_embed_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import napari
from skimage.data import astronaut
from micro_sam.sam_annotator._state import SamState
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just fyi, I have called this AnnotatorState in computational-cell-analytics#240 now.


x = astronaut()

v = napari.Viewer()
v.add_image(x)


@v.bind_key("p")
def print_embed(v):
state = SamState()
print("Image Embeddings are None:", state.image_embeddings is None)


napari.run()
10 changes: 5 additions & 5 deletions micro_sam/sam_annotator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
The interactive annotation tools.
"""

from .annotator import annotator
from .annotator_2d import annotator_2d
from .annotator_3d import annotator_3d
from .annotator_tracking import annotator_tracking
from .image_series_annotator import image_folder_annotator, image_series_annotator
# from .annotator import annotator
# from .annotator_2d import annotator_2d
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these lines commented out?

This is what causedtest_gui.py to fail, since it used to do from micro_sam.sam_annotator import annotator_2d. I've changed the import in test_gui.py to now say from micro_sam.sam_annotator.annotator_2d import annotator_2d. So now it should pass, regardless of whatever you decide to do here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to comment these out locally in order to fix the vigra import issues.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to undo these changes. I assume these import errors are specific to my installation and I will look into this more closely at some later point (and first see if this still persists with a new environment.)

# from .annotator_3d import annotator_3d
# from .annotator_tracking import annotator_tracking
# from .image_series_annotator import image_folder_annotator, image_series_annotator
20 changes: 20 additions & 0 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@


# See
# https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python
def singleton(class_):
instances = {}

def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance


# we probably want this to be a data class
# and I am not sure which singleton pattern to go with
@singleton
class SamState:
def __init__(self):
self.image_embeddings = None
9 changes: 6 additions & 3 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
_DEFAULT_MODEL,
_available_devices,
)
from ._state import SamState

if TYPE_CHECKING:
import napari
Expand Down Expand Up @@ -49,15 +50,17 @@ def embedding_widget(
# Compute the image embeddings
@thread_worker(connect={'started': pbar.show, 'returned': pbar.hide})
def _compute_image_embedding(PREDICTOR, image_data, save_path, ndim=None):
print("Start")
if save_path is not None:
save_path = str(save_path)
global IMAGE_EMBEDDINGS
IMAGE_EMBEDDINGS = precompute_image_embeddings(
state = SamState()
state.image_embeddings = precompute_image_embeddings(
predictor = PREDICTOR,
input_ = image_data,
save_path = save_path,
ndim=ndim,
)
return IMAGE_EMBEDDINGS # returns napari._qt.qthreading.FunctionWorker
print("Done")
return state # returns napari._qt.qthreading.FunctionWorker

return _compute_image_embedding(PREDICTOR, image.data, save_path, ndim=ndim)
3 changes: 1 addition & 2 deletions test/test_gui.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import numpy as np
import pytest

from micro_sam.sam_annotator import annotator_2d
from micro_sam.sam_annotator.annotator_2d import _initialize_viewer
from micro_sam.sam_annotator.annotator_2d import annotator_2d, _initialize_viewer


def _check_layer_initialization(viewer):
Expand Down