Skip to content

Commit

Permalink
added sam predictor widget
Browse files Browse the repository at this point in the history
  • Loading branch information
mese79 committed Jan 11, 2024
1 parent a46e2c1 commit b4cf6e6
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 18 deletions.
2 changes: 2 additions & 0 deletions src/napari_sam_labeling_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__version__ = "0.0.1"
from ._embedding_extractor_widget import EmbeddingExtractorWidget
from ._sam_predictor_widget import SAMPredictorWidget

__all__ = (
"EmbeddingExtractorWidget",
"SAMPredictorWidget",
)
48 changes: 30 additions & 18 deletions src/napari_sam_labeling_tools/_embedding_extractor_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
QFileDialog, QScrollArea, QProgressBar,
)
from qtpy.QtCore import Qt
from qtpy.QtGui import QIntValidator, QDoubleValidator

import h5py
import numpy as np
import torch
from torchvision import transforms

from .widgets import (
ScrollWidgetWrapper,
Expand All @@ -26,19 +24,15 @@
from .utils import (
config
)
from .utils.data import (
DATA_PATCH_SIZE, TARGET_PATCH_SIZE,
patchify, unpatchify
)
from .utils.extract import get_sam_embeddings_for_slice


class EmbeddingExtractorWidget(QWidget):
def __init__(self, napari_viewer: napari.Viewer):
super().__init__()
self.viewer = napari_viewer
self.sam_model = None
self.device = None
self.extract_worker = None
self.storage = None

self.prepare_widget()

Expand All @@ -64,6 +58,10 @@ def prepare_widget(self):
self.extract_button = QPushButton("Extract Embeddings")
self.extract_button.setEnabled(False)
self.extract_button.clicked.connect(self.extract_embeddings)
# stop button
self.stop_button = QPushButton("Stop")
self.stop_button.setEnabled(False)
self.stop_button.clicked.connect(self.stop_extracting)
# progress
self.stack_progress = QProgressBar()
# self.slice_progress = QProgressBar()
Expand All @@ -90,6 +88,7 @@ def prepare_widget(self):
hbox.addWidget(storage_button)
vbox.addLayout(hbox)
vbox.addWidget(self.extract_button)
vbox.addWidget(self.stop_button)
layout.addLayout(vbox)

vbox = QVBoxLayout()
Expand All @@ -99,7 +98,7 @@ def prepare_widget(self):
layout.addLayout(vbox)

gbox = QGroupBox()
gbox.setTitle("Inputs")
gbox.setTitle("Extractor Widget")
gbox.setMinimumWidth(100)
gbox.setLayout(layout)
self.base_layout.addWidget(gbox)
Expand Down Expand Up @@ -141,34 +140,46 @@ def extract_embeddings(self):
return

self.extract_button.setEnabled(False)
worker = create_worker(self.get_stack_sam_embeddings, image_layer, storage_path)
worker.yielded.connect(self.update_extract_progress)
worker.finished.connect(self.extract_is_done)
worker.run()
self.stop_button.setEnabled(True)
self.extract_worker = create_worker(
self.get_stack_sam_embeddings, image_layer, storage_path
)
self.extract_worker.yielded.connect(self.update_extract_progress)
self.extract_worker.finished.connect(self.extract_is_done)
self.extract_worker.run()

def get_stack_sam_embeddings(self, image_layer, storage_path):
# initial sam model
self.sam_model, self.device = SAM.setup_lighthq_sam_model()
sam_model, device = SAM.setup_lighthq_sam_model()
# initial storage hdf5 file
storage = h5py.File(storage_path, "w")
self.storage = h5py.File(storage_path, "w")
# get sam embeddings slice by slice and save them into storage file
num_slices, img_height, img_width = image_layer.data.shape
for slice_index in np_progress(
range(num_slices), desc="get embeddings for slices"
):
slice_grp = storage.create_group(str(slice_index))
slice_grp = self.storage.create_group(str(slice_index))
dataset = slice_grp.create_dataset(
"sam",
shape=(img_height, img_width, SAM.EMBEDDING_SIZE + SAM.PATCH_CHANNELS)
)
dataset[...] = get_sam_embeddings_for_slice(
self.sam_model.image_encoder, self.device,
sam_model.image_encoder, device,
image_layer, slice_index
)

yield (slice_index, num_slices)

storage.close()
self.storage.close()

def stop_extracting(self):
if self.extract_worker is not None:
self.extract_worker.quit()
self.extract_worker = None
if isinstance(self.storage, h5py.File):
self.storage.close()
self.storage = None
self.stop_button.setEnabled(False)

def update_extract_progress(self, values):
curr, total = values
Expand All @@ -179,5 +190,6 @@ def update_extract_progress(self, values):

def extract_is_done(self):
self.extract_button.setEnabled(True)
self.stop_button.setEnabled(False)
print("Extracting is done!")
notif.show_info("Extracting is done!")
Loading

0 comments on commit b4cf6e6

Please sign in to comment.