From 3025b6c87f7df0a4a3424793d1901feaeda88f51 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sun, 10 Sep 2023 12:20:36 +0200
Subject: [PATCH 01/11] Fix CLI scripts in setup.cfg

---
 setup.cfg | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 0119a5d2..362c83b0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,7 +5,7 @@ description = SegmentAnything For Microscopy
 long_description = file: README.md
 long_description_content_type = text/markdown
 url = https://github.com/computational-cell-analytics/micro-sam
-author = Anwai Archit, Constantin Pape
+author = Anwai Archit, Paul Hilt, Genevieve Buckley, Constantin Pape
 author_email = yourname@example.com
 license = MIT
 license_files = LICENSE
@@ -42,15 +42,14 @@ where = .
 [options.entry_points]
 napari.manifest =
     micro-sam = micro_sam:napari.yaml
+console_scripts =
+    micro_sam.annotator = micro_sam.sam_annotator.annotator:main
+    micro_sam.annotator_2d = micro_sam.sam_annotator.annotator_2d:main
+    micro_sam.annotator_3d = micro_sam.sam_annotator.annotator_3d:main
+    micro_sam.annotator_tracking = micro_sam.sam_annotator.annotator_tracking:main
+    micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main
+    micro_sam.precompute_embeddings = micro_sam.util:main
 
 # make sure it gets included in your package
 [options.package_data]
 * = *.yaml
-
-[project.scripts]
-micro_sam.annotator = "micro_sam.sam_annotator.annotator:main"
-micro_sam.annotator_2d = "micro_sam.sam_annotator.annotator_2d:main"
-micro_sam.annotator_3d = "micro_sam.sam_annotator.annotator_3d:main"
-micro_sam.annotator_tracking = "micro_sam.sam_annotator.annotator_tracking:main"
-micro_sam.image_series_annotator = "micro_sam.sam_annotator.image_series_annotator:main"
-micro_sam.precompute_embeddings = "micro_sam.util:main"
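The removed `[project.scripts]` table is `pyproject.toml` (PEP 621) syntax; inside `setup.cfg` setuptools silently ignores it, which is why the console scripts were never installed. Moving the entries under `options.entry_points` / `console_scripts` is the setup.cfg-native equivalent. A quick way to sanity-check the fix after installation is to resolve the entry points — a minimal sketch, assuming Python 3.10+ for the `select` API:

```python
from importlib.metadata import entry_points

# Resolve the installed console scripts; on Python 3.8/3.9 use
# entry_points()["console_scripts"] instead of .select(...).
for ep in entry_points().select(group="console_scripts"):
    if ep.name.startswith("micro_sam."):
        print(f"{ep.name} -> {ep.value}")
        ep.load()  # raises ImportError/AttributeError if the target is wrong
```
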
From 04fd47f9a0b0f03ecc58d9de237432daffc1b9b2 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 14 Sep 2023 10:18:30 +0200
Subject: [PATCH 02/11] Update Syntax and Minor Fixes

---
 micro_sam/training/sam_trainer.py | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 9a30554b..f15b5691 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -131,8 +131,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):
 
         # outer loop is over the batch (different image/patch predictions)
         for m_, y_, ids_, predicted_iou_ in zip(masks, y, sampled_ids, predicted_iou_values):
-            per_object_dice_scores = []
-            per_object_iou_scores = []
+            per_object_dice_scores, per_object_iou_scores = [], []
 
             # inner loop is over the channels, this corresponds to the different predicted objects
             for i, (predicted_obj, predicted_iou) in enumerate(zip(m_, predicted_iou_)):
@@ -157,7 +156,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):
         return loss, mask_loss, iou_regression_loss, mean_model_iou
 
     def _postprocess_outputs(self, masks):
-        """ masks look like -> (B, 1, X, Y)
+        """ "masks" look like -> (B, 1, X, Y)
         where, B is the number of objects, (X, Y) is the input image shape
         """
         instance_labels = []
@@ -174,8 +173,7 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
         masks = [m["masks"] for m in batched_outputs]
         pred_labels = self._postprocess_outputs(masks)
 
-        # we do the condition below to adapt w.r.t. the multimask output
-        # to select the "objectively" best response
+        # we do the condition below to adapt w.r.t. the multimask output to select the "objectively" best response
         if pred_labels.dim() == 5:
             metric = min([self.metric(pred_labels[:, :, i, :, :], sampled_binary_y.to(self.device))
                           for i in range(pred_labels.shape[2])])
@@ -192,10 +190,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
         input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
         image_embeddings = self.model.image_embeddings_oft(input_images)
 
-        loss = 0.0
-        mask_loss = 0.0
-        iou_regression_loss = 0.0
-        mean_model_iou = 0.0
+        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
 
         # this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
         for i in range(0, num_subiter):
@@ -219,9 +214,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
                 mask, l_mask = [], []
                 for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]):
                     best_iou_idx = torch.argmax(_iou)
-
-                    best_mask, best_logits = _m[best_iou_idx], _l[best_iou_idx]
-                    best_mask, best_logits = best_mask[None], best_logits[None]
+                    best_mask, best_logits = _m[best_iou_idx][None], _l[best_iou_idx][None]
                     mask.append(self._sigmoid(best_mask))
                     l_mask.append(best_logits)
@@ -320,8 +313,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
             self.optimizer.zero_grad()
 
             with forward_context():
-                n_samples = self.n_objects_per_batch
-                n_samples = self._update_samples_for_gt_instances(y, n_samples)
+                n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
 
                 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
@@ -375,16 +367,12 @@ def _validate_impl(self, forward_context):
         self.model.eval()
 
-        metric_val = 0.0
-        loss_val = 0.0
-        model_iou_val = 0.0
-        val_iteration = 0
+        metric_val, loss_val, model_iou_val, val_iteration = 0.0, 0.0, 0.0, 0.0
 
         with torch.no_grad():
             for x, y in self.val_loader:
                 with forward_context():
-                    n_samples = self.n_objects_per_batch
-                    n_samples = self._update_samples_for_gt_instances(y, n_samples)
+                    n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)
 
                     (n_pos, n_neg, get_boxes,
                      multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
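The `dim() == 5` branch in `_get_val_metric` picks the best response when the model returns several candidate masks per prompt. A standalone restatement of that selection logic, where the `(B, 1, num_multimask_outputs, X, Y)` shape is an assumption inferred from the docstring in `_postprocess_outputs`:

```python
import torch

def best_multimask_metric(pred_labels, target, metric):
    # With multimask output the predictions carry an extra candidate-mask
    # dimension; evaluate the metric per candidate and keep the best (lowest).
    if pred_labels.dim() == 5:  # assumed (B, 1, num_multimask_outputs, X, Y)
        return min(
            metric(pred_labels[:, :, i, :, :], target)
            for i in range(pred_labels.shape[2])
        )
    return metric(pred_labels, target)
```
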
From f7d6b12fbca806bd5396775f5abdaee2826948e7 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Mon, 18 Sep 2023 13:48:47 +0200
Subject: [PATCH 03/11] Replace Point-Updates with IterativePromptGenerator in Trainer

---
 micro_sam/prompt_generators.py    | 12 +++----
 micro_sam/training/sam_trainer.py | 58 +++++--------------------------
 2 files changed, 12 insertions(+), 58 deletions(-)

diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py
index 51849e3e..37f427cf 100644
--- a/micro_sam/prompt_generators.py
+++ b/micro_sam/prompt_generators.py
@@ -197,7 +197,7 @@ class IterativePromptGenerator:
     """
     def _get_positive_points(self, pos_region, overlap_region):
         positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
-        # we may have objects withput a positive region (= missing true foreground)
+        # we may have objects without a positive region (= missing true foreground)
         # in this case we just sample a point where the model was already correct
         positive_locations = [
             torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
@@ -258,16 +258,12 @@ def __call__(
         self,
         gt: torch.Tensor,
         object_mask: torch.Tensor,
-        current_points: torch.Tensor,
-        current_labels: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Generate the prompts for each object iteratively in the segmentation.
 
         Args:
            The groundtruth segmentation.
            The predicted objects.
-           The current points.
-           Thr current labels.
 
        Returns:
            The updated point prompt coordinates.
@@ -278,7 +274,7 @@ def __call__(
         true_object = gt.to(device)
 
         expected_diff = (object_mask - true_object)
-        neg_region = (expected_diff == 1).to(torch.float)
+        neg_region = (expected_diff == 1).to(torch.float32)
         pos_region = (expected_diff == -1)
         overlap_region = torch.logical_and(object_mask == 1, true_object == 1).to(torch.float32)
@@ -290,7 +286,7 @@ def __call__(
         neg_coordinates = torch.tensor(neg_coordinates)[:, None]
         pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]
 
-        net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
-        net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
+        net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1)
+        net_labels = torch.cat([pos_labels, neg_labels], dim=1)
 
         return net_coords, net_labels
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index f15b5691..3dd11b37 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -237,57 +237,15 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
     def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batched_inputs, logits_masks):
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
         for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
-            net_coords, net_labels = [], []
-
             # here, we get each object in the pairs and do the point choices per-object
-            for pred_obj, true_obj in zip(x1, x2):
-                true_obj = true_obj.to(self.device)
-
-                expected_diff = (pred_obj - true_obj)
-
-                neg_region = (expected_diff == 1).to(torch.float32)
-                pos_region = (expected_diff == -1)
-                overlap_region = torch.logical_and(pred_obj == 1, true_obj == 1).to(torch.float32)
-
-                # POSITIVE POINTS
-                tmp_pos_loc = torch.where(pos_region)
-                if torch.stack(tmp_pos_loc).shape[-1] == 0:
-                    tmp_pos_loc = torch.where(overlap_region)
-
-                pos_index = np.random.choice(len(tmp_pos_loc[1]))
-                pos_coordinates = int(tmp_pos_loc[1][pos_index]), int(tmp_pos_loc[2][pos_index])
-                pos_coordinates = pos_coordinates[::-1]
-                pos_labels = 1
-
-                # NEGATIVE POINTS
-                tmp_neg_loc = torch.where(neg_region)
-                if torch.stack(tmp_neg_loc).shape[-1] == 0:
-                    tmp_true_loc = torch.where(true_obj)
-                    x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
-                    bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
-                                        torch.max(x_coords) + 1, torch.max(y_coords) + 1])
-                    bbox_mask = torch.zeros_like(true_obj).squeeze(0)
-                    bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
-                    bbox_mask = bbox_mask[None].to(self.device)
-
-                    dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
-                    background_mask = abs(dilated_bbox_mask - true_obj)
-                    tmp_neg_loc = torch.where(background_mask)
-
-                neg_index = np.random.choice(len(tmp_neg_loc[1]))
-                neg_coordinates = int(tmp_neg_loc[1][neg_index]), int(tmp_neg_loc[2][neg_index])
-                neg_coordinates = neg_coordinates[::-1]
-                neg_labels = 0
-
-                net_coords.append([pos_coordinates, neg_coordinates])
-                net_labels.append([pos_labels, neg_labels])
-
-            if "point_labels" in _inp.keys():
-                updated_point_coords = torch.cat([_inp["point_coords"], torch.tensor(net_coords)], dim=1)
-                updated_point_labels = torch.cat([_inp["point_labels"], torch.tensor(net_labels)], dim=1)
-            else:
-                updated_point_coords = torch.tensor(net_coords)
-                updated_point_labels = torch.tensor(net_labels)
+            from micro_sam.prompt_generators import IterativePromptGenerator
+            iterative_prompter = IterativePromptGenerator()
+            net_coords, net_labels = iterative_prompter(x2, x1)
+
+            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
+                if "point_coords" in _inp.keys() else net_coords
+            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
+                if "point_labels" in _inp.keys() else net_labels
 
             _inp["point_coords"] = updated_point_coords
             _inp["point_labels"] = updated_point_labels
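With this refactor the per-object point sampling sits behind a single call. A usage sketch of the new two-argument signature — the tensor shapes are assumptions (the generator's `torch.where` indexes three dimensions per object, suggesting per-object masks of shape `(num_objects, 1, H, W)`):

```python
import torch
from micro_sam.prompt_generators import IterativePromptGenerator

# Dummy ground-truth and predicted masks for a single object, chosen so that
# both a false-negative region (positive point) and a false-positive region
# (negative point) exist.
gt = torch.zeros(1, 1, 64, 64)
gt[..., 20:40, 20:40] = 1
pred = torch.zeros(1, 1, 64, 64)
pred[..., 25:45, 25:45] = 1

prompt_generator = IterativePromptGenerator()
coords, labels = prompt_generator(gt, pred)  # no current points/labels anymore
print(coords.shape, labels.shape)  # one positive and one negative point per object
```
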
From 3e0e89338e57ae5b190ad7afa86e91a8759af50b Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Mon, 18 Sep 2023 18:16:28 +0200
Subject: [PATCH 04/11] Update Imports

---
 micro_sam/training/sam_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 3dd11b37..a78af1a0 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -6,10 +6,11 @@
 import torch
 import torch_em
 
-from kornia.morphology import dilation
 from torchvision.utils import make_grid
 
 from torch_em.trainer.logger_base import TorchEmLogger
 
+from ..prompt_generators import IterativePromptGenerator
+
 
 class SamTrainer(torch_em.trainer.DefaultTrainer):
     """Trainer class for training the Segment Anything model.
@@ -238,7 +239,6 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
         for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
             # here, we get each object in the pairs and do the point choices per-object
-            from micro_sam.prompt_generators import IterativePromptGenerator
             iterative_prompter = IterativePromptGenerator()
             net_coords, net_labels = iterative_prompter(x2, x1)

From 686f93ae30aedf1505baae6e07f4e96014f44200 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Tue, 19 Sep 2023 17:12:22 +0200
Subject: [PATCH 05/11] Update Dilation - by shifting the pixels

---
 micro_sam/prompt_generators.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py
index 37f427cf..47b7e3a7 100644
--- a/micro_sam/prompt_generators.py
+++ b/micro_sam/prompt_generators.py
@@ -10,7 +10,6 @@
 from scipy.ndimage import binary_dilation
 
 import torch
-from kornia.morphology import dilation
 
 
 class PointAndBoxPromptGenerator:
@@ -230,13 +229,13 @@ def _get_negative_points(self, negative_region_batched, true_object_batched, gt_
             bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
                                 torch.max(x_coords) + 1, torch.max(y_coords) + 1])
             bbox_mask = torch.zeros_like(true_object).squeeze(0)
-            bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
+
+            custom_df = 3  # custom dilation factor to perform dilation by expanding the pixels of bbox
+            bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, gt.shape[-1]),
+                      max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-2])] = 1
             bbox_mask = bbox_mask[None].to(device)
 
-            # NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
-            # TODO: just expand the pixels of bbox
-            dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(device)).squeeze(0)
-            background_mask = abs(dilated_bbox_mask - true_object)
+            background_mask = abs(bbox_mask - true_object)
             tmp_neg_loc = torch.where(background_mask)
 
             # there is a chance that the object is small to not return a decent-sized bounding box
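Patch 05 drops the kornia dependency by emulating the dilation directly: instead of convolving the bounding-box mask with a structuring element, the box indices are simply widened by a few pixels. A standalone sketch of the idea (the `(1, H, W)` input shape and `custom_df` value follow the patch; the row/column clamping shown here uses the orientation that the later "Fix Pixel Dilation" patch settles on — rows clamped by H, columns by W):

```python
import torch

def expanded_bbox_mask(true_object: torch.Tensor, custom_df: int = 3) -> torch.Tensor:
    """Approximate a dilated bounding-box mask by widening the bbox indices.

    true_object is assumed binary with shape (1, H, W); the clamping keeps the
    expanded box inside the image bounds.
    """
    coords = torch.where(true_object)
    x_coords, y_coords = coords[1], coords[2]
    bbox = (x_coords.min(), y_coords.min(), x_coords.max() + 1, y_coords.max() + 1)

    H, W = true_object.shape[-2], true_object.shape[-1]
    bbox_mask = torch.zeros_like(true_object).squeeze(0)
    bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, H),
              max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, W)] = 1
    return bbox_mask[None]
```
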
From 82a06360eb0d312ba1064768018e69abf52f5828 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Wed, 20 Sep 2023 13:09:22 +0200
Subject: [PATCH 06/11] Update Minor Fixes

---
 micro_sam/prompt_generators.py    |  2 +-
 micro_sam/training/sam_trainer.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py
index 47b7e3a7..5669804d 100644
--- a/micro_sam/prompt_generators.py
+++ b/micro_sam/prompt_generators.py
@@ -235,7 +235,7 @@ def _get_negative_points(self, negative_region_batched, true_object_batched, gt_
                       max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-2])] = 1
             bbox_mask = bbox_mask[None].to(device)
 
-            background_mask = abs(bbox_mask - true_object)
+            background_mask = torch.abs(bbox_mask - true_object)
             tmp_neg_loc = torch.where(background_mask)
 
             # there is a chance that the object is small to not return a decent-sized bounding box
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index a78af1a0..538c2eb7 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -38,6 +38,7 @@ def __init__(
         n_objects_per_batch: Optional[int] = None,
         mse_loss: torch.nn.Module = torch.nn.MSELoss(),
         _sigmoid: torch.nn.Module = torch.nn.Sigmoid(),
+        prompt_generator=IterativePromptGenerator(),
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -46,6 +47,7 @@ def __init__(
         self._sigmoid = _sigmoid
         self.n_objects_per_batch = n_objects_per_batch
         self.n_sub_iteration = n_sub_iteration
+        self.prompt_generator = prompt_generator
         self._kwargs = kwargs
@@ -239,8 +241,7 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
         # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
         for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
             # here, we get each object in the pairs and do the point choices per-object
-            iterative_prompter = IterativePromptGenerator()
-            net_coords, net_labels = iterative_prompter(x2, x1)
+            net_coords, net_labels = self.prompt_generator(x2, x1)
 
             updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                 if "point_coords" in _inp.keys() else net_coords
@@ -256,9 +257,8 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
             _inp["point_coords"] = updated_point_coords
             _inp["point_labels"] = updated_point_labels
 
     #
     def _update_samples_for_gt_instances(self, y, n_samples):
-        num_instances_gt = [len(torch.unique(_y)) for _y in y]
-        if n_samples > min(num_instances_gt):
-            n_samples = min(num_instances_gt) - 1
+        num_instances_gt = torch.amax(y, dim=(1, 2, 3))
+        n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
         return n_samples

From c3b86d0b6320e663951b16088f302b1b3c1d62d4 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Wed, 20 Sep 2023 13:11:51 +0200
Subject: [PATCH 07/11] Refactor device call in training script

---
 finetuning/livecell_finetuning.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index 3521db4b..46a1cef5 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -32,11 +32,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
 
 def finetune_livecell(args):
     """Example code for finetuning SAM on LiveCELL"""
+    # override this (below) if you have some more complex set-up and need to specify the exact gpu
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     # training settings:
     model_type = args.model_type
     checkpoint_path = None  # override this to start training from a custom checkpoint
-    device = "cuda"  # override this if you have some more complex set-up and need to specify the exact gpu
     patch_shape = (520, 740)  # the patch shape for training
     n_objects_per_batch = 25  # this is the number of objects per batch that will be sampled
@@ -105,7 +106,7 @@ def main():
     )
     parser.add_argument(
         "--export_path", "-e",
-        help="Where to export the finetuned model to. The exported model can be use din the annotation tools."
+        help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
     )
     args = parser.parse_args()
     finetune_livecell(args)
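The reworked `_update_samples_for_gt_instances` in patch 06 derives the instance count from the maximal label id per image instead of `torch.unique`, which assumes consecutive instance ids with 0 as background. A sketch of the resulting behavior (the `(B, 1, H, W)` label shape is an assumption):

```python
import torch

def update_samples_for_gt_instances(y: torch.Tensor, n_samples: int) -> int:
    # With consecutive ids and 0 as background, the per-image maximum equals
    # the instance count; cap n_samples by the smallest count in the batch.
    num_instances_gt = torch.amax(y, dim=(1, 2, 3))
    return min(int(num_instances_gt.min()), n_samples)

y = torch.zeros(2, 1, 8, 8, dtype=torch.int64)
y[0, 0, :2, :2], y[0, 0, 4:6, 4:6] = 1, 2   # image 0: two instances
y[1, 0, 3:5, 3:5] = 1                       # image 1: one instance
print(update_samples_for_gt_instances(y, n_samples=25))  # -> 1
```
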
From 158bc2555105fb3aab86c71f5917fd3158800eb5 Mon Sep 17 00:00:00 2001
From: Anwai Archit
Date: Wed, 20 Sep 2023 14:03:43 +0200
Subject: [PATCH 08/11] Fix Pixel Dilation

---
 micro_sam/prompt_generators.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py
index 5669804d..c17c0524 100644
--- a/micro_sam/prompt_generators.py
+++ b/micro_sam/prompt_generators.py
@@ -231,8 +231,8 @@ def _get_negative_points(self, negative_region_batched, true_object_batched, gt_
             bbox_mask = torch.zeros_like(true_object).squeeze(0)
 
             custom_df = 3  # custom dilation factor to perform dilation by expanding the pixels of bbox
-            bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, gt.shape[-1]),
-                      max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-2])] = 1
+            bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, gt.shape[-2]),
+                      max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-1])] = 1
             bbox_mask = bbox_mask[None].to(device)
 
             background_mask = torch.abs(bbox_mask - true_object)

From 8a577d45c5cf93e62f5bd3786123a0f0be79bffa Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Mon, 25 Sep 2023 15:02:49 +0200
Subject: [PATCH 09/11] Fix CLI for embedding precomputation

---
 micro_sam/precompute_state.py | 6 +++---
 setup.cfg                     | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py
index 6844c0f9..8d4eca20 100644
--- a/micro_sam/precompute_state.py
+++ b/micro_sam/precompute_state.py
@@ -101,7 +101,7 @@ def precompute_state(
     model_type: str = util._DEFAULT_MODEL,
     checkpoint_path: Optional[Union[os.PathLike, str]] = None,
     key: Optional[str] = None,
-    ndim: Union[int] = None,
+    ndim: Optional[int] = None,
     tile_shape: Optional[Tuple[int, int]] = None,
     halo: Optional[Tuple[int, int]] = None,
     precompute_amg_state: bool = False,
@@ -158,8 +158,8 @@ def main():
     parser.add_argument(
         "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
     )
-    parser.add_argument("-n", "--ndim")
-    parser.add_argument("-p", "--precompute_amg_state")
+    parser.add_argument("-n", "--ndim", type=int)
+    parser.add_argument("-p", "--precompute_amg_state", action="store_true")
 
     args = parser.parse_args()
     precompute_state(
diff --git a/setup.cfg b/setup.cfg
index 362c83b0..d92cec09 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -48,7 +48,7 @@ console_scripts =
     micro_sam.annotator_3d = micro_sam.sam_annotator.annotator_3d:main
     micro_sam.annotator_tracking = micro_sam.sam_annotator.annotator_tracking:main
     micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main
-    micro_sam.precompute_embeddings = micro_sam.util:main
+    micro_sam.precompute_embeddings = micro_sam.precompute_state:main
 
 # make sure it gets included in your package
 [options.package_data]
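The `argparse` fix in patch 09 matters because a flag declared without `action="store_true"` expects a string value, and any non-empty string — including `"False"` — is truthy when passed on to `precompute_state`; likewise, `type=int` makes `--ndim` arrive as an integer instead of a string. A minimal demonstration of the difference (the parser and argument names are illustrative, not from the patch):

```python
import argparse

parser = argparse.ArgumentParser("demo")
parser.add_argument("--broken")                      # takes a value: --broken False -> "False"
parser.add_argument("--fixed", action="store_true")  # plain presence flag: --fixed -> True

args = parser.parse_args(["--broken", "False", "--fixed"])
print(bool(args.broken), args.fixed)  # True True -- the string "False" is still truthy
```
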
From fd75f698da660c4021863a371f428d79231ff8fd Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Tue, 26 Sep 2023 09:17:26 +0200
Subject: [PATCH 10/11] Fix tuple annotations

Use `Tuple` instead of `tuple` to be compatible with older python versions.
---
 micro_sam/prompt_based_segmentation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/micro_sam/prompt_based_segmentation.py b/micro_sam/prompt_based_segmentation.py
index 3c4b848a..6ca0ba73 100644
--- a/micro_sam/prompt_based_segmentation.py
+++ b/micro_sam/prompt_based_segmentation.py
@@ -3,7 +3,7 @@
 """
 
 import warnings
-from typing import Optional
+from typing import Optional, Tuple
 
 import numpy as np
 from nifty.tools import blocking
@@ -308,7 +308,7 @@ def segment_from_mask(
     use_box: bool = True,
     use_mask: bool = True,
     use_points: bool = False,
-    original_size: Optional[tuple[int, ...]] = None,
+    original_size: Optional[Tuple[int, ...]] = None,
     multimask_output: bool = False,
     return_all: bool = False,
     return_logits: bool = False,
@@ -365,7 +365,7 @@ def segment_from_box(
     box: np.ndarray,
     image_embeddings: Optional[util.ImageEmbeddings] = None,
     i: Optional[int] = None,
-    original_size: Optional[tuple[int, ...]] = None,
+    original_size: Optional[Tuple[int, ...]] = None,
     multimask_output: bool = False,
     return_all: bool = False,
 ):
@@ -405,7 +405,7 @@ def segment_from_box_and_points(
     labels: np.ndarray,
     image_embeddings: Optional[util.ImageEmbeddings] = None,
     i: Optional[int] = None,
-    original_size: Optional[tuple[int, ...]] = None,
+    original_size: Optional[Tuple[int, ...]] = None,
     multimask_output: bool = False,
     return_all: bool = False,
 ):
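Built-in generics like `tuple[int, ...]` only became subscriptable with PEP 585 in Python 3.9; on 3.8 such an annotation raises at import time, because function signatures are evaluated eagerly (unless `from __future__ import annotations` defers them). A quick sketch of the failure mode this patch and the next one avoid:

```python
# On Python 3.8 this definition fails at import time:
#   TypeError: 'type' object is not subscriptable
def f(x: tuple[int, ...]) -> tuple[int, ...]:
    return x

# The typing alias works on every version the project supports:
from typing import Tuple

def g(x: Tuple[int, ...]) -> Tuple[int, ...]:
    return x
```
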
From 6c5e146331519f88811aa661f1e33a56023fe3e4 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Tue, 26 Sep 2023 09:21:31 +0200
Subject: [PATCH 11/11] Fix more type annotations

---
 micro_sam/instance_segmentation.py | 2 +-
 micro_sam/prompt_generators.py     | 8 ++++----
 micro_sam/visualization.py         | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index f1c5e644..6609ec42 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -52,7 +52,7 @@ def __getitem__(self, index):
 
 def mask_data_to_segmentation(
     masks: List[Dict[str, Any]],
-    shape: tuple[int, ...],
+    shape: Tuple[int, ...],
     with_background: bool,
     min_object_size: int = 0,
 ) -> np.ndarray:
diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py
index c17c0524..61b06d13 100644
--- a/micro_sam/prompt_generators.py
+++ b/micro_sam/prompt_generators.py
@@ -4,7 +4,7 @@
 """
 
 from collections.abc import Mapping
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 from scipy.ndimage import binary_dilation
@@ -155,10 +155,10 @@ def __call__(
         self,
         segmentation: np.ndarray,
         segmentation_id: int,
-        bbox_coordinates: Mapping[int, tuple],
+        bbox_coordinates: Mapping[int, Tuple],
         center_coordinates: Optional[Mapping[int, np.ndarray]] = None
-    ) -> tuple[
-        Optional[list[tuple]], Optional[list[int]], Optional[list[tuple]], np.ndarray
+    ) -> Tuple[
+        Optional[List[Tuple]], Optional[List[int]], Optional[List[Tuple]], np.ndarray
     ]:
         """Generate the prompts for one object in the segmentation.
 
diff --git a/micro_sam/visualization.py b/micro_sam/visualization.py
index c26b3b6c..0426e193 100644
--- a/micro_sam/visualization.py
+++ b/micro_sam/visualization.py
@@ -146,7 +146,7 @@ def _project_tiled_embeddings(image_embeddings):
 
 def project_embeddings_for_visualization(
     image_embeddings: ImageEmbeddings
-) -> tuple[np.ndarray, tuple[float, ...]]:
+) -> Tuple[np.ndarray, Tuple[float, ...]]:
     """Project image embeddings to pixel-wise PCA.
 
     Args: