Skip to content

Commit

Permalink
get patch/target patch sizes based on image size; put them to hdf att…
Browse files Browse the repository at this point in the history
…ributes
  • Loading branch information
mese79 committed Jan 26, 2024
1 parent d5f67ff commit 85d7d89
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 115 deletions.
31 changes: 25 additions & 6 deletions src/napari_sam_labeling_tools/_embedding_extractor_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from .utils.data import (
get_stack_sizes
)
from .utils.extract import get_sam_embeddings_for_slice
from .utils.extract import (
get_patch_sizes,
get_sam_embeddings_for_slice
)


class EmbeddingExtractorWidget(QWidget):
Expand Down Expand Up @@ -67,7 +70,8 @@ def prepare_widget(self):
self.stop_button.setMinimumWidth(150)
# progress
self.stack_progress = QProgressBar()
# self.slice_progress = QProgressBar()
# patch info
self.patch_label = QLabel("Patch Sizes:")

self.viewer.layers.events.inserted.connect(self.check_input_layers)
self.viewer.layers.events.removed.connect(self.check_input_layers)
Expand Down Expand Up @@ -100,7 +104,7 @@ def prepare_widget(self):
vbox = QVBoxLayout()
vbox.setContentsMargins(0, 5, 0, 0)
vbox.addWidget(self.stack_progress)
# vbox.addWidget(self.slice_progress)
vbox.addWidget(self.patch_label)
layout.addLayout(vbox)

gbox = QGroupBox()
Expand Down Expand Up @@ -144,30 +148,45 @@ def extract_embeddings(self):
if storage_path is None or len(storage_path) < 6:
notif.show_error("No storage path was set.")
return
# get proper patch sizes
_, img_height, img_width = get_stack_sizes(image_layer.data)
patch_size, target_patch_size = get_patch_sizes(img_height, img_width)
self.patch_label.setText(f"Patch Sizes: ({patch_size}, {target_patch_size})")
self.update()

self.extract_button.setEnabled(False)
self.stop_button.setEnabled(True)
self.extract_worker = create_worker(
self.get_stack_sam_embeddings, image_layer, storage_path
self.get_stack_sam_embeddings,
image_layer, storage_path, patch_size, target_patch_size
)
self.extract_worker.yielded.connect(self.update_extract_progress)
self.extract_worker.finished.connect(self.extract_is_done)
self.extract_worker.run()

def get_stack_sam_embeddings(self, image_layer, storage_path):
def get_stack_sam_embeddings(
self, image_layer, storage_path, patch_size, target_patch_size
):
# initial sam model
sam_model, device = SAM.setup_lighthq_sam_model()
# initial storage hdf5 file
self.storage = h5py.File(storage_path, "w")
# get sam embeddings slice by slice and save them into storage file
num_slices, img_height, img_width = get_stack_sizes(image_layer.data)
self.storage.attrs["num_slices"] = num_slices
self.storage.attrs["img_height"] = img_height
self.storage.attrs["img_width"] = img_width
self.storage.attrs["patch_size"] = patch_size
self.storage.attrs["target_patch_size"] = target_patch_size

for slice_index in np_progress(
range(num_slices), desc="get embeddings for slices"
):
image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data
slice_grp = self.storage.create_group(str(slice_index))
get_sam_embeddings_for_slice(
sam_model.image_encoder, device, image, slice_grp
image, patch_size, target_patch_size,
sam_model.image_encoder, device, slice_grp
)

yield (slice_index, num_slices)
Expand Down
20 changes: 12 additions & 8 deletions src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
get_layer,
)
from .utils.data import (
IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE,
get_stack_sizes, get_patch_indices,
get_num_target_patches, is_image_rgb,
)
Expand All @@ -49,6 +48,8 @@ def __init__(self, napari_viewer: napari.Viewer):
self.device = None
self.sam_model = None
self.sam_predictor = None
self.patch_size = 512
self.target_patch_size = 128

self.prepare_widget()

Expand Down Expand Up @@ -306,6 +307,9 @@ def select_storage(self):
self.storage_textbox.setText(selected_file)
# load the storage
self.storage = h5py.File(selected_file, "r")
self.patch_size = self.storage.attrs.get("patch_size", self.patch_size)
self.target_patch_size = self.storage.attrs.get(
"target_patch_size", self.target_patch_size)

def get_user_prompts(self):
user_prompts = None
Expand Down Expand Up @@ -395,29 +399,29 @@ def calc_similarity_matrix(self, prompts_mask, curr_slice):
all_coords = prompt_mask_positions[slice_mask] # n x 3 (slice, y, x)
patch_indices = get_patch_indices(
all_coords[:, 1:], img_height, img_width,
IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
self.patch_size, self.target_patch_size
)
for p_i in np_progress(np.unique(patch_indices), desc="reading patches"):
patch_features = slice_dataset[p_i]
patch_coords = all_coords[patch_indices == p_i]
prompt_avg_vector += patch_features[
patch_coords[:, 1] % TARGET_PATCH_SIZE,
patch_coords[:, 2] % TARGET_PATCH_SIZE
patch_coords[:, 1] % self.target_patch_size,
patch_coords[:, 2] % self.target_patch_size
].sum(axis=0)
prompt_avg_vector /= len(prompt_mask_positions)

# shape: N x target_size x target_size x C
curr_slice_features = self.storage[str(curr_slice)]["sam"][:]
patch_rows, patch_cols = get_num_target_patches(
img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
img_height, img_width, self.patch_size, self.target_patch_size
)
# reshape it to the image size + padding
curr_slice_features = curr_slice_features.reshape(
patch_rows, patch_cols, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE, -1
patch_rows, patch_cols, self.target_patch_size, self.target_patch_size, -1
)
curr_slice_features = np.moveaxis(curr_slice_features, 1, 2).reshape(
patch_rows * TARGET_PATCH_SIZE,
patch_cols * TARGET_PATCH_SIZE,
patch_rows * self.target_patch_size,
patch_cols * self.target_patch_size,
-1
)
# skip paddings
Expand Down
20 changes: 12 additions & 8 deletions src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
get_layer,
)
from .utils.data import (
IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE,
get_stack_sizes, get_patch_indices,
get_num_target_patches
)
Expand All @@ -50,6 +49,8 @@ def __init__(self, napari_viewer: napari.Viewer):
self.rf_model = None
self.device = None
self.sam_model = None
self.patch_size = 512
self.target_patch_size = 128

self.prepare_widget()

Expand Down Expand Up @@ -349,6 +350,9 @@ def select_storage(self):
self.storage_textbox.setText(selected_file)
# load the storage
self.storage = h5py.File(selected_file, "r")
self.patch_size = self.storage.attrs.get("patch_size", self.patch_size)
self.target_patch_size = self.storage.attrs.get(
"target_patch_size", self.target_patch_size)

def add_labels_layer(self):
self.image_layer = get_layer(
Expand Down Expand Up @@ -424,16 +428,16 @@ def get_train_data(self):
][:, 1:] # omit the slice dim
patch_indices = get_patch_indices(
slice_coords, img_height, img_width,
IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
self.patch_size, self.target_patch_size
)
grp_key = str(slice_index)
slice_dataset = self.storage[grp_key]["sam"]
for p_i in np.unique(patch_indices):
patch_coords = slice_coords[patch_indices == p_i]
patch_features = slice_dataset[p_i]
train_data[count: count + len(patch_coords)] = patch_features[
patch_coords[:, 0] % TARGET_PATCH_SIZE,
patch_coords[:, 1] % TARGET_PATCH_SIZE
patch_coords[:, 0] % self.target_patch_size,
patch_coords[:, 1] % self.target_patch_size
]
labels[
count: count + len(patch_coords)
Expand Down Expand Up @@ -594,14 +598,14 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
segmentation_image = np.vstack(segmentation_image)
# reshape into the image size + padding
patch_rows, patch_cols = get_num_target_patches(
img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
img_height, img_width, self.patch_size, self.target_patch_size
)
segmentation_image = segmentation_image.reshape(
patch_rows, patch_cols, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE
patch_rows, patch_cols, self.target_patch_size, self.target_patch_size
)
segmentation_image = np.moveaxis(segmentation_image, 1, 2).reshape(
patch_rows * TARGET_PATCH_SIZE,
patch_cols * TARGET_PATCH_SIZE
patch_rows * self.target_patch_size,
patch_cols * self.target_patch_size
)
# skip paddings
segmentation_image = segmentation_image[:img_height, :img_width]
Expand Down
Loading

0 comments on commit 85d7d89

Please sign in to comment.