From 80f1429107184a3165cf767496893f9ac39fa9d4 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 19 Nov 2024 23:41:08 +0100 Subject: [PATCH 1/2] Add validation window for avoiding recomputing embeddings --- micro_sam/sam_annotator/_state.py | 3 +-- micro_sam/sam_annotator/_widgets.py | 20 +++++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index ee42e4ebb..0dc9d5137 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -87,8 +87,7 @@ def initialize_predictor( # Initialize the model if necessary. if predictor is None: self.predictor, state = util.get_sam_model( - device=device, model_type=model_type, - checkpoint_path=checkpoint_path, return_state=True + device=device, model_type=model_type, checkpoint_path=checkpoint_path, return_state=True ) if prefer_decoder and "decoder_state" in state: self.decoder = get_decoder( diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4a58a42a7..5a3356525 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -940,8 +940,9 @@ def _create_settings_widget(self): self.device = "auto" device_options = ["auto"] + util._available_devices() - self.device_dropdown, layout = self._add_choice_param("device", self.device, device_options, - tooltip=get_tooltip("embedding", "device")) + self.device_dropdown, layout = self._add_choice_param( + "device", self.device, device_options, tooltip=get_tooltip("embedding", "device") + ) setting_values.layout().addLayout(layout) # Create UI for the save path. @@ -1062,6 +1063,16 @@ def _validate_inputs(self): # Otherwise we either don't have an embedding path or it is empty. We can proceed in both cases. return False + def _validate_existing_embeddings(self, state): + if state.image_embeddings is None: + return False + else: + val_results = { + "message_type": "info", + "message": "Embeddings have already been precomputed. Press OK to recompute the embeddings." + } + return _generate_message(val_results["message_type"], val_results["message"]) + def __call__(self, skip_validate=False): # Validate user inputs. if not skip_validate and self._validate_inputs(): @@ -1071,8 +1082,11 @@ def __call__(self, skip_validate=False): image = self.image_selection.get_value() # Update the image embeddings: - # Reset the state. state = AnnotatorState() + if self._validate_existing_embeddings(state): + return + + # Reset the state. state.reset_state() # Get image dimensions. From e66c5a9420cea68fd1989e43e9e4f486a9f62a77 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 20 Nov 2024 18:38:33 +0100 Subject: [PATCH 2/2] Check for embeddings computed to avoid clearing layer objects --- micro_sam/sam_annotator/_annotator.py | 4 ++++ micro_sam/sam_annotator/_widgets.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index fcd58cc44..0b800d844 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -153,6 +153,10 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None: def _update_image(self, segmentation_result=None): state = AnnotatorState() + # Whether embeddings already exist and avoid clearing objects in layers. + if state._embeddings_are_same: + return + # Update the image shape if it has changed. if state.image_shape != self._shape: if len(state.image_shape) != self._ndim: diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 5a3356525..955aab52d 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1084,8 +1084,10 @@ def __call__(self, skip_validate=False): # Update the image embeddings: state = AnnotatorState() if self._validate_existing_embeddings(state): + state._embeddings_are_same = True # Whether embeddings already exist to control existing objects in layers. return + state._embeddings_are_same = False # Reset the state. state.reset_state()