
Commit

Add --filter sam (#670)
* Add sam

* Ask where the sam model is

* sort by area descending

* logger info

* don't check closed

* 0.05 to 0.5

* remove a logger info

* keep changing minsize and maxsize

* improve logger info
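
In practice, the new mode should be reachable through the existing seeds entry point, e.g. `python -m jcvi.graphics.grabseeds seeds image.jpg --filter sam --sam-checkpoint sam_vit_h_4b8939.pth`; this command line is inferred from the options added in the diff below and is not documented in the commit itself.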
tanghaibao authored May 20, 2024
1 parent 6f6d556 commit 621eb47
Showing 1 changed file with 100 additions and 27 deletions.
jcvi/graphics/grabseeds.py: 127 changes (100 additions, 27 deletions)
@@ -155,6 +155,56 @@ def calibrate(self, pixel_cm_ratio: float, tr: np.ndarray):
self.calibrated = True


def sam(img: np.ndarray, checkpoint: str) -> List[dict]:
"""
Use Segment Anything Model (SAM) to segment objects.
"""
try:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
logger.fatal("segment_anything not installed. Please install it first.")
sys.exit(1)

model_type = "vit_h"
if not op.exists(checkpoint):
raise AssertionError(
f"File `{checkpoint}` not found, please specify --sam-checkpoint"
)
sam = sam_model_registry[model_type](checkpoint=checkpoint)
logger.info("Using SAM model `%s` (%s)", model_type, checkpoint)
mask_generator = SamAutomaticMaskGenerator(sam)
return mask_generator.generate(img)


def is_overlapping(mask1: dict, mask2: dict, threshold=0.5):
"""
Check if bounding boxes of mask1 and mask2 overlap more than the given
threshold.
"""
x1, y1, w1, h1 = mask1["bbox"]
x2, y2, w2, h2 = mask2["bbox"]
x_overlap = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
y_overlap = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
intersection = x_overlap * y_overlap
return intersection / min(w1 * h1, w2 * h2) > threshold


def deduplicate_masks(masks: List[dict], threshold=0.5):
"""
Deduplicate masks to retain only the foreground objects.
"""
masks_sorted = sorted(masks, key=lambda x: x["area"])
retained_masks = []

for mask in masks_sorted:
if not any(
is_overlapping(mask, retained_mask, threshold)
for retained_mask in retained_masks
):
retained_masks.append(mask)
return retained_masks


def rgb_to_triplet(rgb: str) -> RGBTuple:
"""
Convert RGB string to triplet.
@@ -295,12 +345,12 @@ def add_seeds_options(p, args):
g2 = p.add_argument_group("Object recognition")
g2.add_argument(
"--minsize",
default=0.05,
default=0.2,
type=float,
help="Min percentage of object to image",
)
g2.add_argument(
"--maxsize", default=50, type=float, help="Max percentage of object to image"
"--maxsize", default=20, type=float, help="Max percentage of object to image"
)
g2.add_argument(
"--count", default=100, type=int, help="Report max number of objects"
@@ -313,7 +363,7 @@ def add_seeds_options(p, args):
)

g3 = p.add_argument_group("De-noise")
valid_filters = ("canny", "roberts", "sobel", "otsu")
valid_filters = ("canny", "otsu", "roberts", "sam", "sobel")
g3.add_argument(
"--filter",
default="canny",
@@ -335,6 +385,9 @@ def add_seeds_options(p, args):
g3.add_argument(
"--border", default=5, type=int, help="Remove image border of certain pixels"
)
g3.add_argument(
"--sam-checkpoint", default="sam_vit_h_4b8939.pth", help="SAM checkpoint file"
)

g4 = p.add_argument_group("Output")
g4.add_argument("--calibrate", help="JSON file to correct distance and color")
@@ -647,39 +700,57 @@ def seeds(args):
_, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, nrows=1, figsize=(iopts.w, iopts.h))
# Edge detection
img_gray = rgb2gray(img)
logger.debug("Running %s edge detection ...", ff)
w, h = img_gray.shape
canvas_size = w * h
min_size = int(round(canvas_size * opts.minsize / 100))
max_size = int(round(canvas_size * opts.maxsize / 100))

logger.debug("Running %s edge detection …", ff)
if ff == "canny":
edges = canny(img_gray, sigma=opts.sigma)
elif ff == "otsu":
thresh = threshold_otsu(img_gray)
edges = img_gray > thresh
elif ff == "roberts":
edges = roberts(img_gray)
elif ff == "sobel":
edges = sobel(img_gray)
elif ff == "otsu":
thresh = threshold_otsu(img_gray)
edges = img_gray > thresh
edges = clear_border(edges, buffer_size=opts.border)
selem = disk(kernel)
closed = closing(edges, selem) if kernel else edges
filled = binary_fill_holes(closed)

# Watershed algorithm
if opts.watershed:
distance = distance_transform_edt(filled)
local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False)
coordinates = peak_local_max(distance, threshold_rel=0.05)
markers, nmarkers = label(local_maxi, return_num=True)
logger.debug("Identified %d watershed markers", nmarkers)
labels = watershed(closed, markers, mask=filled)
if ff == "sam":
masks = sam(img, opts.sam_checkpoint)
filtered_masks = [
mask for mask in masks if min_size <= mask["area"] <= max_size
]
deduplicated_masks = deduplicate_masks(filtered_masks)
logger.info(
"SAM: %d (raw) → %d (size filtered) → %d (deduplicated)",
len(masks),
len(filtered_masks),
len(deduplicated_masks),
)
labels = np.zeros(img_gray.shape, dtype=int)
for i, mask in enumerate(deduplicated_masks):
labels[mask["segmentation"]] = i + 1
labels = clear_border(labels)
else:
labels = label(filled)
edges = clear_border(edges, buffer_size=opts.border)
selem = disk(kernel)
closed = closing(edges, selem) if kernel else edges
filled = binary_fill_holes(closed)

# Watershed algorithm
if opts.watershed:
distance = distance_transform_edt(filled)
local_maxi = peak_local_max(distance, threshold_rel=0.05, indices=False)
coordinates = peak_local_max(distance, threshold_rel=0.05)
markers, nmarkers = label(local_maxi, return_num=True)
logger.debug("Identified %d watershed markers", nmarkers)
labels = watershed(closed, markers, mask=filled)
else:
labels = label(filled)

# Object size filtering
w, h = img_gray.shape
canvas_size = w * h
min_size = int(round(canvas_size * opts.minsize / 100))
max_size = int(round(canvas_size * opts.maxsize / 100))
logger.debug(
"Find objects with pixels between %d (%d%%) and %d (%d%%)",
"Find objects with pixels between %d (%.2f%%) and %d (%d%%)",
min_size,
opts.minsize,
max_size,
@@ -694,7 +765,8 @@ def seeds(args):
if opts.watershed:
params += ", watershed"
ax2.set_title(f"Edge detection\n({params})")
closed = gray2rgb(closed)
if ff != "sam":
closed = gray2rgb(closed)
ax2_img = labels
if opts.edges:
ax2_img = closed
@@ -714,6 +786,7 @@ def seeds(args):
# Calculate region properties
rp = regionprops(labels)
rp = [x for x in rp if min_size <= x.area <= max_size]
rp.sort(key=lambda x: x.area, reverse=True)
nb_labels = len(rp)
logger.debug("A total of %d objects identified.", nb_labels)
objects = []

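For readers unfamiliar with the new helpers, here is a minimal standalone sketch (assuming a jcvi checkout that includes this commit) of how `is_overlapping` and `deduplicate_masks` behave on hand-built masks following the `SamAutomaticMaskGenerator` output convention, where `bbox` is `(x, y, w, h)`:

from jcvi.graphics.grabseeds import deduplicate_masks, is_overlapping

# Three toy masks: a large background-like region and two small seeds.
background = {"bbox": (0, 0, 100, 100), "area": 10_000}
seed_a = {"bbox": (10, 10, 20, 20), "area": 400}   # nested inside background
seed_b = {"bbox": (60, 60, 20, 20), "area": 400}   # disjoint from seed_a

# Overlap is measured as intersection area over the smaller bbox area, so a
# mask fully nested in a larger one scores 1.0 and exceeds the 0.5 threshold.
assert is_overlapping(background, seed_a)      # 400 / 400 = 1.0 > 0.5
assert not is_overlapping(seed_a, seed_b)      # no intersection

# Masks are considered smallest-first, so both seeds are retained and the
# enclosing background mask is dropped because it overlaps a retained mask.
kept = deduplicate_masks([background, seed_a, seed_b])
assert len(kept) == 2

This mirrors the filtering done in the `sam` branch of `seeds()`, where size filtering and deduplication are applied before the masks are painted into the label image.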