Commit 9c683f2
changed to patch-based storage; fixed RGB support
mese79 committed Jan 25, 2024
1 parent 639f2df commit 9c683f2
Showing 8 changed files with 219 additions and 163 deletions.
6 changes: 3 additions & 3 deletions src/napari_sam_labeling_tools/SAM/__init__.py
@@ -6,15 +6,15 @@

 INPUT_SIZE = 1024
 FEATURE_H = FEATURE_W = 64
-EMBEDDING_SIZE = 256
+ENCODER_OUT_CHANNELS = 256
 PATCH_SIZE = 256
-PATCH_CHANNELS = 64
+EMBED_PATCH_CHANNELS = 64


 sam_transform = transforms.Compose([
     transforms.Resize(
         (INPUT_SIZE, INPUT_SIZE),
-        interpolation=transforms.InterpolationMode.BILINEAR,
+        interpolation=transforms.InterpolationMode.BICUBIC,
         antialias=True
     ),
 ])
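Note: the transform above now upsamples with bicubic instead of bilinear interpolation. A quick usage sketch (assuming a torch tensor input, which torchvision's Resize accepts):

    import torch
    from napari_sam_labeling_tools.SAM import sam_transform

    # resize a (C, H, W) patch to SAM's 1024x1024 input, bicubic with antialias
    patch = torch.rand(3, 256, 256)
    resized = sam_transform(patch)
    print(tuple(resized.shape))  # (3, 1024, 1024)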
9 changes: 2 additions & 7 deletions src/napari_sam_labeling_tools/_embedding_extractor_widget.py
@@ -165,15 +165,10 @@ def get_stack_sam_embeddings(self, image_layer, storage_path):
             range(num_slices), desc="get embeddings for slices"
         ):
             image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data
-            sam_embeddings = get_sam_embeddings_for_slice(
-                sam_model.image_encoder, device, image
-            )
             slice_grp = self.storage.create_group(str(slice_index))
-            dataset = slice_grp.create_dataset(
-                "sam", shape=sam_embeddings.shape
-            )
-            # shape: patch_rows x patch_cols x target_size x target_size x C
-            dataset[...] = sam_embeddings
+            get_sam_embeddings_for_slice(
+                sam_model.image_encoder, device, image, slice_grp
+            )

             yield (slice_index, num_slices)
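Note: with the patch-based layout, each slice group now holds a single "sam" dataset written inside get_sam_embeddings_for_slice, instead of one array copied back through the widget. A minimal read-back sketch, assuming the storage file is HDF5 opened via h5py (the diff only shows a dict-like create_group/create_dataset API, so the format and file name here are assumptions):

    import h5py

    with h5py.File("embeddings_storage.hdf5", "r") as storage:
        slice_0_features = storage["0"]["sam"][:]
        # shape: N x target_size x target_size x C, one row per image patch
        print(slice_0_features.shape)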

36 changes: 20 additions & 16 deletions src/napari_sam_labeling_tools/_sam_predictor_widget.py
@@ -1,6 +1,7 @@
 import napari
 import napari.utils.notifications as notif
 from napari.utils.events import Event
+from napari.utils import progress as np_progress

 from qtpy.QtWidgets import (
     QVBoxLayout, QWidget, QHBoxLayout,
@@ -19,7 +20,7 @@
 from .utils import (
     colormaps, config
 )
-from .utils.data import get_stack_sizes
+from .utils.data import get_stack_sizes, is_image_rgb


 class SAMPredictorWidget(QWidget):
@@ -250,16 +251,23 @@ def get_prompt_labels(self, user_prompts, num_slices, img_height, img_width):
         prompts_merged_mask = np.zeros(
             (num_slices, img_height, img_width), dtype=np.uint8
         )
-        for prompt in user_prompts:
+        for prompt in np_progress(user_prompts, desc="getting prompts mask"):
+            # prepare the image for sam
+            slice_index = int(prompt[0, 0]) if is_box_prompt else int(prompt[0])
+            if num_slices > 1:
+                input_img = self.image_layer.data[slice_index]
+            else:
+                input_img = self.image_layer.data
+            if not is_image_rgb(input_img):
+                input_img = np.repeat(
+                    self.image_layer.data[slice_index, :, :, np.newaxis], 3,
+                    axis=-1
+                )
+            self.sam_predictor.set_image(input_img)

             # first dim of prompt is the slice index.
             # sam prompt need to be as x,y coordinates (numpy is y,x).
             if is_box_prompt:
-                slice_index = prompt[0, 0].astype(np.int32)
-                input_img = np.repeat(
-                    self.image_layer.data[slice_index, :, :, np.newaxis],
-                    3, axis=-1
-                )
-                self.sam_predictor.set_image(input_img)
                 # napari box: depends on direction of drawing :( (y, x)
                 # SAM box: top-left, bottom-right (x, y)
                 top_left = (prompt[:, 2].min(), prompt[:, 1].min())
@@ -273,12 +281,6 @@ def get_prompt_labels(self, user_prompts, num_slices, img_height, img_width):
                     hq_token_only=False,
                 )
             else:
-                slice_index = prompt[0].astype(np.int32)
-                input_img = np.repeat(
-                    self.image_layer.data[slice_index, :, :, np.newaxis],
-                    3, axis=-1
-                )
-                self.sam_predictor.set_image(input_img)
                 point = prompt[1:][[1, 0]]
                 masks, scores, logits = self.sam_predictor.predict(
                     point_coords=point[np.newaxis, :],
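Note: the RGB fix hinges on the new is_image_rgb helper; grayscale slices are tiled to three channels only when needed, once per prompt, instead of unconditionally in both prompt branches. The helper lives in utils/data.py and is not part of this diff; a plausible sketch, assuming RGB(A) is detected by a trailing channel axis:

    import numpy as np

    def is_image_rgb(image: np.ndarray) -> bool:
        # hypothetical implementation: treat a trailing axis of 3 or 4 as RGB(A)
        return image.ndim > 2 and image.shape[-1] in (3, 4)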
@@ -307,9 +309,11 @@ def predict_prompts(self):
             notif.show_warning("No prompts was given!")
             return

-        num_slices, img_height, img_width = get_stack_sizes(self.image_layer)
+        num_slices, img_height, img_width = get_stack_sizes(self.image_layer.data)
         if self.new_layer_checkbox.checkState() == Qt.Checked:
-            segmentation_data = np.zeros((img_height, img_width), dtype=np.uint8)
+            segmentation_data = np.zeros(
+                (num_slices, img_height, img_width), dtype=np.uint8
+            )
             self.segmentation_layer = self.viewer.add_labels(
                 segmentation_data, name="Prompt Segmentations"
             )
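Note: two related fixes here: get_stack_sizes now receives the raw array (image_layer.data) rather than the napari layer object, and the labels layer is allocated as a full (num_slices, H, W) volume so each slice of a stack gets its own segmentation plane. A sketch of the shape logic get_stack_sizes presumably applies (hypothetical implementation; only its call signature appears in this commit):

    import numpy as np

    def get_stack_sizes(image_data: np.ndarray) -> tuple:
        # hypothetical: return (num_slices, height, width) for
        # 2D, 2D RGB(A), or 3D stack input
        if image_data.ndim == 2:
            return 1, image_data.shape[0], image_data.shape[1]
        if image_data.ndim == 3 and image_data.shape[-1] in (3, 4):
            return 1, image_data.shape[0], image_data.shape[1]
        return image_data.shape[0], image_data.shape[1], image_data.shape[2]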
106 changes: 63 additions & 43 deletions src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py
@@ -3,7 +3,7 @@
 from napari.utils.events import Event
 from napari.qt.threading import create_worker
 from napari.utils import progress as np_progress
-from napari.utils import Colormap
+# from napari.utils import Colormap

 from qtpy.QtWidgets import (
     QHBoxLayout, QVBoxLayout, QWidget,
@@ -24,8 +24,9 @@
     get_layer,
 )
 from .utils.data import (
-    TARGET_PATCH_SIZE,
-    get_stack_sizes, get_patch_position
+    IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE,
+    get_stack_sizes, get_patch_indices,
+    get_num_target_patches, is_image_rgb,
 )
 from .utils import (
     colormaps, config
@@ -334,15 +335,22 @@ def get_prompt_labels(self, num_slices, img_height, img_width):
             (num_slices, img_height, img_width), dtype=np.uint8
         )
         for prompt in user_prompts:
+            # prepare the image for sam
+            slice_index = int(prompt[0, 0]) if is_box_prompt else int(prompt[0])
+            if num_slices > 1:
+                input_img = self.image_layer.data[slice_index]
+            else:
+                input_img = self.image_layer.data
+            if not is_image_rgb(input_img):
+                input_img = np.repeat(
+                    self.image_layer.data[slice_index, :, :, np.newaxis], 3,
+                    axis=-1
+                )
+            self.sam_predictor.set_image(input_img)

             # first dim of prompt is the slice index.
             # sam prompt need to be as x,y coordinates (numpy is y,x).
             if is_box_prompt:
-                slice_index = prompt[0, 0].astype(np.int32)
-                input_image = np.repeat(
-                    self.image_layer.data[slice_index, :, :, np.newaxis],
-                    3, axis=-1
-                )
-                self.sam_predictor.set_image(input_image)
                 # napari box: depends on direction of drawing :( (y, x)
                 # SAM box: top-left, bottom-right (x, y)
                 top_left = (prompt[:, 2].min(), prompt[:, 1].min())
@@ -356,12 +364,6 @@
                     hq_token_only=False,
                 )
             else:
-                slice_index = prompt[0].astype(np.int32)
-                input_image = np.repeat(
-                    self.image_layer.data[slice_index, :, :, np.newaxis],
-                    3, axis=-1
-                )
-                self.sam_predictor.set_image(input_image)
                 point = prompt[1:][[1, 0]]
                 masks, scores, logits = self.sam_predictor.predict(
                     point_coords=point[np.newaxis, :],
@@ -381,34 +383,44 @@ def get_similarity_matrix(self, prompts_mask, curr_slice):
         Calculate Cosine similarity for all pixels with prompts' mask average vector
         (in sam's embedding space).
         """
-        prompt_avg_vector = np.zeros(SAM.EMBEDDING_SIZE + SAM.PATCH_CHANNELS)
+        _, img_height, img_width = get_stack_sizes(self.image_layer.data)
+        total_channels = SAM.ENCODER_OUT_CHANNELS + SAM.EMBED_PATCH_CHANNELS
+        prompt_avg_vector = np.zeros(total_channels)
         prompt_mask_positions = np.argwhere(prompts_mask == 1)
-        for index in np.unique(prompt_mask_positions[:, 0]):
+        for index in np_progress(
+            np.unique(prompt_mask_positions[:, 0]), desc="Calc. similarity matrix"
+        ):
             slice_dataset = self.storage[str(index)]["sam"]
-            coords_in_slice = prompt_mask_positions[:, 0] == index
-            for pos in prompt_mask_positions[coords_in_slice]:
-                y = pos[1]
-                x = pos[2]
-                # get patch including pixel position
-                patch_row, patch_col = get_patch_position(y, x)
-                patch_features = slice_dataset[patch_row, patch_col]
-                # sum pixel embeddings
+            slice_mask = prompt_mask_positions[:, 0] == index
+            all_coords = prompt_mask_positions[slice_mask]  # n x 3 (slice, y, x)
+            patch_indices = get_patch_indices(
+                all_coords[:, 1:], img_height, img_width,
+                IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
+            )
+            for p_i in np_progress(np.unique(patch_indices), desc="reading patches"):
+                patch_features = slice_dataset[p_i]
+                patch_coords = all_coords[patch_indices == p_i]
                 prompt_avg_vector += patch_features[
-                    y % TARGET_PATCH_SIZE, x % TARGET_PATCH_SIZE
-                ]
+                    patch_coords[:, 1] % TARGET_PATCH_SIZE,
+                    patch_coords[:, 2] % TARGET_PATCH_SIZE
+                ].sum(axis=0)
         prompt_avg_vector /= len(prompt_mask_positions)

-        # shape: patch_rows x patch_cols x target_size x target_size x C
+        # shape: N x target_size x target_size x C
        curr_slice_features = self.storage[str(curr_slice)]["sam"][:]
-        patch_rows, patch_cols = curr_slice_features.shape[:2]
+        patch_rows, patch_cols = get_num_target_patches(
+            img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE
+        )
         # reshape it to the image size + padding
-        curr_slice_features = curr_slice_features.transpose([0, 2, 1, 3, 4]).reshape(
+        curr_slice_features = curr_slice_features.reshape(
+            patch_rows, patch_cols, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE, -1
+        )
+        curr_slice_features = np.moveaxis(curr_slice_features, 1, 2).reshape(
             patch_rows * TARGET_PATCH_SIZE,
             patch_cols * TARGET_PATCH_SIZE,
             -1
         )
         # skip paddings
-        _, img_height, img_width = self.image_layer.data.shape
         curr_slice_features = curr_slice_features[:img_height, :img_width]
         # calc. cosine similarity
         sim_mat = np.dot(curr_slice_features, prompt_avg_vector)
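Note: the rewritten loop groups prompt pixels by patch and reads each stored patch once, instead of re-reading a patch for every pixel. The grid helpers come from utils/data.py and are outside this diff; a minimal sketch of the row-major tiling they appear to implement (hypothetical implementations, ignoring any overlap between IMAGE_PATCH_SIZE input windows and their TARGET_PATCH_SIZE centers):

    import numpy as np

    def get_num_target_patches(img_height, img_width, patch_size, target_size):
        # hypothetical: number of target-size tiles covering the padded image
        return (int(np.ceil(img_height / target_size)),
                int(np.ceil(img_width / target_size)))

    def get_patch_indices(coords, img_height, img_width, patch_size, target_size):
        # hypothetical: map (y, x) pixel coords to flat row-major tile indices
        _, patch_cols = get_num_target_patches(
            img_height, img_width, patch_size, target_size
        )
        return ((coords[:, 0] // target_size) * patch_cols
                + coords[:, 1] // target_size)

Also note that the closing np.dot yields a true cosine similarity only if the per-pixel features and prompt_avg_vector are L2-normalized upstream, which this view does not show.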
@@ -436,10 +448,10 @@ def predict(self, whole_stack=False):
             notif.show_error("No storage is selected!")
             return

-        num_slices, img_height, img_width = get_stack_sizes(self.image_layer)
+        num_slices, img_height, img_width = get_stack_sizes(self.image_layer.data)
         if self.new_layer_checkbox.checkState() == Qt.Checked:
             self.segmentation_layer = self.viewer.add_labels(
-                np.zeros((img_height, img_width), dtype=np.uint8),
+                np.zeros((num_slices, img_height, img_width), dtype=np.uint8),
                 name="Segmentations"
             )
         else:
@@ -469,12 +481,12 @@ def predict(self, whole_stack=False):
                 self.is_prompt_changed = False
             if self.show_intermediate_checkbox.checkState() == Qt.Checked:
                 # add sam predictor result's layer
-                layer = self.viewer.add_labels(
+                self.viewer.add_labels(
                     data=self.prompts_mask, name="Prompt Labels", opacity=0.55
                 )
-                layer.colormap = Colormap(
-                    np.array([[0.73, 0.48, 0.75, 1.0]])
-                )
+                # layer.colormap = Colormap(
+                #     [[0.73, 0.48, 0.75, 1.0]]
+                # )

             slice_indices = []
             if not whole_stack:
@@ -494,7 +506,7 @@ def predict(self, whole_stack=False):
         self.prediction_worker.run()

     def run_prediction(self, slice_indices, prompts_mask):
-        for slice_index in np_progress(slice_indices):
+        for slice_index in np_progress(slice_indices, desc="Predicting slices"):
             sim_mat = self.get_similarity_matrix(prompts_mask, slice_index)
             high_sim_mask = np.zeros_like(sim_mat, dtype=np.uint8)
             high_sim_mask[
@@ -567,11 +579,19 @@ def run_prediction(self, slice_indices, prompts_mask):

     def predict_slice(self, slice_index, point_prompts):
         sam_masks = []
-        input_image = np.repeat(
-            self.image_layer.data[slice_index, :, :, np.newaxis],
-            3, axis=-1
-        )
-        self.sam_predictor.set_image(input_image)
+        # prepare the image for sam
+        num_slices, img_height, img_width = get_stack_sizes(self.image_layer.data)
+        if num_slices > 1:
+            input_img = self.image_layer.data[slice_index]
+        else:
+            input_img = self.image_layer.data
+        if not is_image_rgb(input_img):
+            input_img = np.repeat(
+                self.image_layer.data[slice_index, :, :, np.newaxis], 3,
+                axis=-1
+            )
+        self.sam_predictor.set_image(input_img)

         for _, point in enumerate(point_prompts):
             point_labels = np.array([1])
             point_coords = point[np.newaxis, :]
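Note: the loop that follows (truncated in this view) feeds each point to SAM one at a time using the standard segment-anything predictor API, sketched below with a made-up point; the HQ variant used in this repo additionally accepts hq_token_only, as in the box branch above:

    import numpy as np

    # assumes `sam_predictor` is an initialized SamPredictor whose image
    # was already set via sam_predictor.set_image(...)
    point_coords = np.array([[128, 64]])  # one (x, y) point prompt
    point_labels = np.array([1])          # 1 marks a foreground point
    masks, scores, logits = sam_predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )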
(Diffs for the remaining 4 changed files were not loaded in this view.)