Skip to content

Commit

Permalink
sort by area descending
Browse files Browse the repository at this point in the history
  • Loading branch information
tanghaibao committed May 20, 2024
1 parent 1c389a0 commit d91838e
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions jcvi/graphics/grabseeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def calibrate(self, pixel_cm_ratio: float, tr: np.ndarray):
self.calibrated = True


def sam(img: np.ndarray) -> List[dict]:
def sam(img: np.ndarray, checkpoint: str) -> List[dict]:
"""
Use Segment Anything Model (SAM) to segment objects.
"""
Expand All @@ -166,11 +166,10 @@ def sam(img: np.ndarray) -> List[dict]:
sys.exit(1)

model_type = "vit_h"
checkpoint = "sam_vit_h_4b8939.pth"
if not op.exists(checkpoint):
checkpoint_dir = input("Enter the path to `sam_vit_h_4b8939.pth`: ")
checkpoint = op.join(checkpoint_dir, "sam_vit_h_4b8939.pth")
assert op.exists(checkpoint), f"File `{checkpoint}` not found"
raise AssertionError(
f"File `{checkpoint}` not found, please specify --sam-checkpoint"
)
sam = sam_model_registry[model_type](checkpoint=checkpoint)
logger.info("Using SAM model `%s` (%s)", model_type, checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)
Expand Down Expand Up @@ -387,6 +386,9 @@ def add_seeds_options(p, args):
g3.add_argument(
"--border", default=5, type=int, help="Remove image border of certain pixels"
)
g3.add_argument(
"--sam-checkpoint", default="sam_vit_h_4b8939.pth", help="SAM checkpoint file"
)

g4 = p.add_argument_group("Output")
g4.add_argument("--calibrate", help="JSON file to correct distance and color")
Expand Down Expand Up @@ -715,7 +717,7 @@ def seeds(args):
elif ff == "sobel":
edges = sobel(img_gray)
if ff == "sam":
masks = sam(img)
masks = sam(img, opts.sam_checkpoint)
filtered_masks = [
mask for mask in masks if min_size <= mask["area"] <= max_size
]
Expand Down Expand Up @@ -779,6 +781,7 @@ def seeds(args):
# Calculate region properties
rp = regionprops(labels)
rp = [x for x in rp if min_size <= x.area <= max_size]
rp.sort(key=lambda x: x.area, reverse=True)
nb_labels = len(rp)
logger.debug("A total of %d objects identified.", nb_labels)
objects = []
Expand Down

0 comments on commit d91838e

Please sign in to comment.