From d91838e6798331224d71996202fb4c6570ffa713 Mon Sep 17 00:00:00 2001 From: Haibao Tang Date: Sun, 19 May 2024 23:40:04 -0700 Subject: [PATCH] sort by area descending --- jcvi/graphics/grabseeds.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/jcvi/graphics/grabseeds.py b/jcvi/graphics/grabseeds.py index 800ef001..74cd209d 100644 --- a/jcvi/graphics/grabseeds.py +++ b/jcvi/graphics/grabseeds.py @@ -155,7 +155,7 @@ def calibrate(self, pixel_cm_ratio: float, tr: np.ndarray): self.calibrated = True -def sam(img: np.ndarray) -> List[dict]: +def sam(img: np.ndarray, checkpoint: str) -> List[dict]: """ Use Segment Anything Model (SAM) to segment objects. """ @@ -166,11 +166,10 @@ def sam(img: np.ndarray) -> List[dict]: sys.exit(1) model_type = "vit_h" - checkpoint = "sam_vit_h_4b8939.pth" if not op.exists(checkpoint): - checkpoint_dir = input("Enter the path to `sam_vit_h_4b8939.pth`: ") - checkpoint = op.join(checkpoint_dir, "sam_vit_h_4b8939.pth") - assert op.exists(checkpoint), f"File `{checkpoint}` not found" + raise AssertionError( + f"File `{checkpoint}` not found, please specify --sam-checkpoint" + ) sam = sam_model_registry[model_type](checkpoint=checkpoint) logger.info("Using SAM model `%s` (%s)", model_type, checkpoint) mask_generator = SamAutomaticMaskGenerator(sam) @@ -387,6 +386,9 @@ def add_seeds_options(p, args): g3.add_argument( "--border", default=5, type=int, help="Remove image border of certain pixels" ) + g3.add_argument( + "--sam-checkpoint", default="sam_vit_h_4b8939.pth", help="SAM checkpoint file" + ) g4 = p.add_argument_group("Output") g4.add_argument("--calibrate", help="JSON file to correct distance and color") @@ -715,7 +717,7 @@ def seeds(args): elif ff == "sobel": edges = sobel(img_gray) if ff == "sam": - masks = sam(img) + masks = sam(img, opts.sam_checkpoint) filtered_masks = [ mask for mask in masks if min_size <= mask["area"] <= max_size ] @@ -779,6 +781,7 @@ def seeds(args): # Calculate region properties rp = regionprops(labels) rp = [x for x in rp if min_size <= x.area <= max_size] + rp.sort(key=lambda x: x.area, reverse=True) nb_labels = len(rp) logger.debug("A total of %d objects identified.", nb_labels) objects = []