-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update model predictor adaptor for bioimage models
- Loading branch information
Showing
3 changed files
with
138 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from micro_sam import model_zoo | ||
|
||
|
||
def main(): | ||
parser = model_zoo._get_modelzoo_parser() | ||
args = parser.parse_args() | ||
|
||
model_zoo.get_modelzoo_yaml( | ||
image_path=args.input_path, | ||
box_prompts=None, | ||
model_type=args.model_type, | ||
output_path=args.output_path, | ||
doc_path=args.doc_path | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from segment_anything.predictor import SamPredictor | ||
|
||
|
||
class PredictorAdaptor(SamPredictor): | ||
"""Wrapper around the SamPredictor to be used by BioImage.IO model format. | ||
This model supports the same functionality as SamPredictor and can provide mask segmentations | ||
from box, point or mask input prompts. | ||
""" | ||
def __call__( | ||
self, | ||
input_image: torch.Tensor, | ||
image_embeddings: Optional[torch.Tensor] = None, | ||
box_prompts: Optional[torch.Tensor] = None | ||
) -> torch.Tensor: | ||
if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed | ||
pass # do nothing | ||
elif self.is_image_set and image_embeddings is not None: | ||
raise NotImplementedError # TODO: replace the image embeedings | ||
elif image_embeddings is not None: | ||
pass # TODO set the image embeddings | ||
# self.features = image_embeddings | ||
elif not self.is_image_set: | ||
self.set_torch_image(input_image) # compute the image embeddings | ||
|
||
instance_segmentation, _, _ = self.predict_torch( | ||
point_coords=None, | ||
point_labels=None, | ||
boxes=box_prompts, | ||
multimask_output=False | ||
) | ||
# TODO get the image embeddings via image_embeddings = self.features | ||
# and return them | ||
return instance_segmentation |