From 85d7d89caaaa1ce833f00930e39d4fdcabfe335f Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 26 Jan 2024 15:17:07 +0100 Subject: [PATCH] get patch/target patch sizes based on image size; put them to hdf attributes --- .../_embedding_extractor_widget.py | 31 +++- .../_sam_prompt_segmentation_widget.py | 20 ++- .../_sam_rf_segmentation_widget.py | 20 ++- src/napari_sam_labeling_tools/utils/data.py | 169 +++++++++--------- .../utils/extract.py | 31 +++- 5 files changed, 156 insertions(+), 115 deletions(-) diff --git a/src/napari_sam_labeling_tools/_embedding_extractor_widget.py b/src/napari_sam_labeling_tools/_embedding_extractor_widget.py index 35e4b96..f9973b3 100644 --- a/src/napari_sam_labeling_tools/_embedding_extractor_widget.py +++ b/src/napari_sam_labeling_tools/_embedding_extractor_widget.py @@ -25,7 +25,10 @@ from .utils.data import ( get_stack_sizes ) -from .utils.extract import get_sam_embeddings_for_slice +from .utils.extract import ( + get_patch_sizes, + get_sam_embeddings_for_slice +) class EmbeddingExtractorWidget(QWidget): @@ -67,7 +70,8 @@ def prepare_widget(self): self.stop_button.setMinimumWidth(150) # progress self.stack_progress = QProgressBar() - # self.slice_progress = QProgressBar() + # patch info + self.patch_label = QLabel("Patch Sizes:") self.viewer.layers.events.inserted.connect(self.check_input_layers) self.viewer.layers.events.removed.connect(self.check_input_layers) @@ -100,7 +104,7 @@ def prepare_widget(self): vbox = QVBoxLayout() vbox.setContentsMargins(0, 5, 0, 0) vbox.addWidget(self.stack_progress) - # vbox.addWidget(self.slice_progress) + vbox.addWidget(self.patch_label) layout.addLayout(vbox) gbox = QGroupBox() @@ -144,30 +148,45 @@ def extract_embeddings(self): if storage_path is None or len(storage_path) < 6: notif.show_error("No storage path was set.") return + # get proper patch sizes + _, img_height, img_width = get_stack_sizes(image_layer.data) + patch_size, target_patch_size = get_patch_sizes(img_height, img_width) + self.patch_label.setText(f"Patch Sizes: ({patch_size}, {target_patch_size})") + self.update() self.extract_button.setEnabled(False) self.stop_button.setEnabled(True) self.extract_worker = create_worker( - self.get_stack_sam_embeddings, image_layer, storage_path + self.get_stack_sam_embeddings, + image_layer, storage_path, patch_size, target_patch_size ) self.extract_worker.yielded.connect(self.update_extract_progress) self.extract_worker.finished.connect(self.extract_is_done) self.extract_worker.run() - def get_stack_sam_embeddings(self, image_layer, storage_path): + def get_stack_sam_embeddings( + self, image_layer, storage_path, patch_size, target_patch_size + ): # initial sam model sam_model, device = SAM.setup_lighthq_sam_model() # initial storage hdf5 file self.storage = h5py.File(storage_path, "w") # get sam embeddings slice by slice and save them into storage file num_slices, img_height, img_width = get_stack_sizes(image_layer.data) + self.storage.attrs["num_slices"] = num_slices + self.storage.attrs["img_height"] = img_height + self.storage.attrs["img_width"] = img_width + self.storage.attrs["patch_size"] = patch_size + self.storage.attrs["target_patch_size"] = target_patch_size + for slice_index in np_progress( range(num_slices), desc="get embeddings for slices" ): image = image_layer.data[slice_index] if num_slices > 1 else image_layer.data slice_grp = self.storage.create_group(str(slice_index)) get_sam_embeddings_for_slice( - sam_model.image_encoder, device, image, slice_grp + image, patch_size, target_patch_size, + sam_model.image_encoder, device, slice_grp ) yield (slice_index, num_slices) diff --git a/src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py b/src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py index 297e795..9403b9b 100644 --- a/src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py +++ b/src/napari_sam_labeling_tools/_sam_prompt_segmentation_widget.py @@ -24,7 +24,6 @@ get_layer, ) from .utils.data import ( - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE, get_stack_sizes, get_patch_indices, get_num_target_patches, is_image_rgb, ) @@ -49,6 +48,8 @@ def __init__(self, napari_viewer: napari.Viewer): self.device = None self.sam_model = None self.sam_predictor = None + self.patch_size = 512 + self.target_patch_size = 128 self.prepare_widget() @@ -306,6 +307,9 @@ def select_storage(self): self.storage_textbox.setText(selected_file) # load the storage self.storage = h5py.File(selected_file, "r") + self.patch_size = self.storage.attrs.get("patch_size", self.patch_size) + self.target_patch_size = self.storage.attrs.get( + "target_patch_size", self.target_patch_size) def get_user_prompts(self): user_prompts = None @@ -395,29 +399,29 @@ def calc_similarity_matrix(self, prompts_mask, curr_slice): all_coords = prompt_mask_positions[slice_mask] # n x 3 (slice, y, x) patch_indices = get_patch_indices( all_coords[:, 1:], img_height, img_width, - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + self.patch_size, self.target_patch_size ) for p_i in np_progress(np.unique(patch_indices), desc="reading patches"): patch_features = slice_dataset[p_i] patch_coords = all_coords[patch_indices == p_i] prompt_avg_vector += patch_features[ - patch_coords[:, 1] % TARGET_PATCH_SIZE, - patch_coords[:, 2] % TARGET_PATCH_SIZE + patch_coords[:, 1] % self.target_patch_size, + patch_coords[:, 2] % self.target_patch_size ].sum(axis=0) prompt_avg_vector /= len(prompt_mask_positions) # shape: N x target_size x target_size x C curr_slice_features = self.storage[str(curr_slice)]["sam"][:] patch_rows, patch_cols = get_num_target_patches( - img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + img_height, img_width, self.patch_size, self.target_patch_size ) # reshape it to the image size + padding curr_slice_features = curr_slice_features.reshape( - patch_rows, patch_cols, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE, -1 + patch_rows, patch_cols, self.target_patch_size, self.target_patch_size, -1 ) curr_slice_features = np.moveaxis(curr_slice_features, 1, 2).reshape( - patch_rows * TARGET_PATCH_SIZE, - patch_cols * TARGET_PATCH_SIZE, + patch_rows * self.target_patch_size, + patch_cols * self.target_patch_size, -1 ) # skip paddings diff --git a/src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py b/src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py index 455843b..25c6210 100644 --- a/src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py +++ b/src/napari_sam_labeling_tools/_sam_rf_segmentation_widget.py @@ -26,7 +26,6 @@ get_layer, ) from .utils.data import ( - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE, get_stack_sizes, get_patch_indices, get_num_target_patches ) @@ -50,6 +49,8 @@ def __init__(self, napari_viewer: napari.Viewer): self.rf_model = None self.device = None self.sam_model = None + self.patch_size = 512 + self.target_patch_size = 128 self.prepare_widget() @@ -349,6 +350,9 @@ def select_storage(self): self.storage_textbox.setText(selected_file) # load the storage self.storage = h5py.File(selected_file, "r") + self.patch_size = self.storage.attrs.get("patch_size", self.patch_size) + self.target_patch_size = self.storage.attrs.get( + "target_patch_size", self.target_patch_size) def add_labels_layer(self): self.image_layer = get_layer( @@ -424,7 +428,7 @@ def get_train_data(self): ][:, 1:] # omit the slice dim patch_indices = get_patch_indices( slice_coords, img_height, img_width, - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + self.patch_size, self.target_patch_size ) grp_key = str(slice_index) slice_dataset = self.storage[grp_key]["sam"] @@ -432,8 +436,8 @@ def get_train_data(self): patch_coords = slice_coords[patch_indices == p_i] patch_features = slice_dataset[p_i] train_data[count: count + len(patch_coords)] = patch_features[ - patch_coords[:, 0] % TARGET_PATCH_SIZE, - patch_coords[:, 1] % TARGET_PATCH_SIZE + patch_coords[:, 0] % self.target_patch_size, + patch_coords[:, 1] % self.target_patch_size ] labels[ count: count + len(patch_coords) @@ -594,14 +598,14 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width): segmentation_image = np.vstack(segmentation_image) # reshape into the image size + padding patch_rows, patch_cols = get_num_target_patches( - img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + img_height, img_width, self.patch_size, self.target_patch_size ) segmentation_image = segmentation_image.reshape( - patch_rows, patch_cols, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE + patch_rows, patch_cols, self.target_patch_size, self.target_patch_size ) segmentation_image = np.moveaxis(segmentation_image, 1, 2).reshape( - patch_rows * TARGET_PATCH_SIZE, - patch_cols * TARGET_PATCH_SIZE + patch_rows * self.target_patch_size, + patch_cols * self.target_patch_size ) # skip paddings segmentation_image = segmentation_image[:img_height, :img_width] diff --git a/src/napari_sam_labeling_tools/utils/data.py b/src/napari_sam_labeling_tools/utils/data.py index fa006e9..5145fe3 100644 --- a/src/napari_sam_labeling_tools/utils/data.py +++ b/src/napari_sam_labeling_tools/utils/data.py @@ -3,9 +3,8 @@ import torch.nn.functional as F -# PATCH_EMBEDDING_PATCH_SIZE = 256 -IMAGE_PATCH_SIZE = 128 -TARGET_PATCH_SIZE = 64 +# IMAGE_PATCH_SIZE = 512 +# TARGET_PATCH_SIZE = 128 def patchify(imgs, patch_size, target_size): @@ -29,12 +28,94 @@ def patchify(imgs, patch_size, target_size): return patches +def get_num_target_patches(img_height, img_width, patch_size, target_size): + margin = (patch_size - target_size) // 2 + pad_right = patch_size - (img_width % patch_size) + patch_size - margin + pad_bottom = patch_size - (img_height % patch_size) + patch_size - margin + num_patches_w = int((img_width + pad_right - patch_size) / target_size) + 1 + num_patches_h = int((img_height + pad_bottom - patch_size) / target_size) + 1 + + return num_patches_h, num_patches_w + + +def get_target_patches(patches, patch_size, target_size): + """ + patches: (B, C, patch_size, patch_size) + out: ( + B, target_size, target_size, C + ) + """ + margin = (patch_size - target_size) // 2 + + return patches[ + :, :, margin: margin + target_size, margin: margin + target_size + ].permute([0, 2, 3, 1]) + + +def get_patch_index( + pixels_y, pixels_x, img_height, img_width, patch_size, target_patch_size +): + """Gets patch index that contains the given one pixel coordinate.""" + total_rows, total_cols = get_num_target_patches( + img_height, img_width, patch_size, target_patch_size + ) + patch_index = ( + pixels_y // target_patch_size) * total_cols + (pixels_x // target_patch_size) + + return patch_index + + +def get_patch_indices(pixel_coords, img_height, img_width, patch_size, target_patch_size): + """Gets patch indices that contains the given pixel coordinates.""" + total_rows, total_cols = get_num_target_patches( + img_height, img_width, patch_size, target_patch_size + ) + ys = pixel_coords[:, 0] + xs = pixel_coords[:, 1] + patch_indices = (ys // target_patch_size) * total_cols + (xs // target_patch_size) + + return patch_indices + + +def get_patch_position(pix_y, pix_x, target_patch_size): + """Gets patch position that contains the given pixel coordinates.""" + # patch_row = int(np.ceil(pix_y / TARGET_PATCH_SIZE)) + # patch_col = int(np.ceil(pix_x / TARGET_PATCH_SIZE)) + patch_row = pix_y // target_patch_size + patch_col = pix_x // target_patch_size + + return patch_row, patch_col + + +def is_image_rgb(image_data): + return image_data.shape[-1] == 3 + + +def is_stacked(image_data): + dims = len(image_data.shape) + if is_image_rgb(image_data): + return dims == 4 + return dims == 3 + + +def get_stack_sizes(image_data): + num_slices = 1 + img_height = image_data.shape[0] + img_width = image_data.shape[1] + if is_stacked(image_data): + num_slices = image_data.shape[0] + img_height = image_data.shape[1] + img_width = image_data.shape[2] + + return num_slices, img_height, img_width + + def unpatchify( embed_patches, img_h, img_w, patch_size, target_size ): """ patches: (B*N, C, patch_size, patch_size) - out: (B, embeding_size, H, W) + out: (B, embedding_size, H, W) """ bn, c, _, _ = embed_patches.shape margin = (patch_size - target_size) // 2 @@ -105,83 +186,3 @@ def unpatchify_np( ] = target_patches return padded_images[:, :, :img_h, :img_w] - - -def get_num_target_patches(img_height, img_width, patch_size, target_size): - margin = (patch_size - target_size) // 2 - pad_right = patch_size - (img_width % patch_size) + patch_size - margin - pad_bottom = patch_size - (img_height % patch_size) + patch_size - margin - num_patches_w = int((img_width + pad_right - patch_size) / target_size) + 1 - num_patches_h = int((img_height + pad_bottom - patch_size) / target_size) + 1 - - return num_patches_h, num_patches_w - - -def get_target_patches(patches, patch_size, target_size): - """ - patches: (B, C, patch_size, patch_size) - out: ( - B, target_size, target_size, C - ) - """ - margin = (patch_size - target_size) // 2 - - return patches[ - :, :, margin: margin + target_size, margin: margin + target_size - ].permute([0, 2, 3, 1]) - - -def get_patch_index(pixels_y, pixels_x, img_height, img_width, patch_size, target_size): - """Gets patch index that contains the given one pixel coordinate.""" - total_rows, total_cols = get_num_target_patches( - img_height, img_width, patch_size, target_size - ) - patch_index = ( - pixels_y // TARGET_PATCH_SIZE) * total_cols + (pixels_x // TARGET_PATCH_SIZE) - - return patch_index - - -def get_patch_indices(pixel_coords, img_height, img_width, patch_size, target_size): - """Gets patch indices that contains the given pixel coordinates.""" - total_rows, total_cols = get_num_target_patches( - img_height, img_width, patch_size, target_size - ) - ys = pixel_coords[:, 0] - xs = pixel_coords[:, 1] - patch_indices = (ys // TARGET_PATCH_SIZE) * total_cols + (xs // TARGET_PATCH_SIZE) - - return patch_indices - - -def get_patch_position(pix_y, pix_x): - """Gets patch position that contains the given pixel coordinates.""" - # patch_row = int(np.ceil(pix_y / TARGET_PATCH_SIZE)) - # patch_col = int(np.ceil(pix_x / TARGET_PATCH_SIZE)) - patch_row = pix_y // TARGET_PATCH_SIZE - patch_col = pix_x // TARGET_PATCH_SIZE - - return patch_row, patch_col - - -def is_image_rgb(image_data): - return image_data.shape[-1] == 3 - - -def is_stacked(image_data): - dims = len(image_data.shape) - if is_image_rgb(image_data): - return dims == 4 - return dims == 3 - - -def get_stack_sizes(image_data): - num_slices = 1 - img_height = image_data.shape[0] - img_width = image_data.shape[1] - if is_stacked(image_data): - num_slices = image_data.shape[0] - img_height = image_data.shape[1] - img_width = image_data.shape[2] - - return num_slices, img_height, img_width diff --git a/src/napari_sam_labeling_tools/utils/extract.py b/src/napari_sam_labeling_tools/utils/extract.py index 6e8a45d..164b039 100644 --- a/src/napari_sam_labeling_tools/utils/extract.py +++ b/src/napari_sam_labeling_tools/utils/extract.py @@ -6,8 +6,6 @@ from torchvision import transforms from .data import ( - IMAGE_PATCH_SIZE, - TARGET_PATCH_SIZE, patchify, get_target_patches, get_num_target_patches, is_image_rgb, ) @@ -17,7 +15,22 @@ ) -def get_sam_embeddings_for_slice(sam_encoder, device, image, storage_group: h5py.Group): +def get_patch_sizes(img_height, img_width): + patch_size = 512 + target_patch_size = 128 + img_min_dim = min(img_height, img_width) + while img_min_dim - patch_size < 100: + patch_size = patch_size // 2 + if patch_size < 256: + target_patch_size = patch_size // 2 + + return patch_size, target_patch_size + + +def get_sam_embeddings_for_slice( + image, patch_size, target_patch_size, + sam_encoder, device, storage_group: h5py.Group +): """get sam features for one slice.""" img_height, img_width = image.shape[:2] # image to torch tensor @@ -32,7 +45,7 @@ def get_sam_embeddings_for_slice(sam_encoder, device, image, storage_group: h5py # to resize encoder output back to the input patch size embedding_transform = transforms.Compose([ transforms.Resize( - (IMAGE_PATCH_SIZE, IMAGE_PATCH_SIZE), + (patch_size, patch_size), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True ), @@ -40,18 +53,18 @@ def get_sam_embeddings_for_slice(sam_encoder, device, image, storage_group: h5py ]) # get sam encoder output for image patches - data_patches = patchify(img_data, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE) + data_patches = patchify(img_data, patch_size, target_patch_size) num_patches = len(data_patches) batch_size = 10 num_batches = int(np.ceil(num_patches / batch_size)) # prepare storage for the slice embeddings target_patch_rows, target_patch_cols = get_num_target_patches( - img_height, img_width, IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + img_height, img_width, patch_size, target_patch_size ) total_channels = ENCODER_OUT_CHANNELS + EMBED_PATCH_CHANNELS dataset = storage_group.create_dataset( "sam", shape=( - num_patches, TARGET_PATCH_SIZE, TARGET_PATCH_SIZE, total_channels + num_patches, target_patch_size, target_patch_size, total_channels ) ) @@ -74,11 +87,11 @@ def get_sam_embeddings_for_slice(sam_encoder, device, image, storage_group: h5py start: start + num_out, :, :, :ENCODER_OUT_CHANNELS ] = get_target_patches( embedding_transform(output.cpu()), - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + patch_size, target_patch_size ) dataset[ start: start + num_out, :, :, ENCODER_OUT_CHANNELS: ] = get_target_patches( embedding_transform(embed_output.cpu()), - IMAGE_PATCH_SIZE, TARGET_PATCH_SIZE + patch_size, target_patch_size )