Skip to content

Commit

Permalink
Merge branch 'master' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 27, 2023
2 parents d539307 + 855ceda commit a2269b1
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 104 deletions.
5 changes: 3 additions & 2 deletions finetuning/livecell_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):

def finetune_livecell(args):
"""Example code for finetuning SAM on LiveCELL"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu
patch_shape = (520, 740) # the patch shape for training
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled

Expand Down Expand Up @@ -105,7 +106,7 @@ def main():
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be use din the annotation tools."
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
args = parser.parse_args()
finetune_livecell(args)
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __getitem__(self, index):

def mask_data_to_segmentation(
masks: List[Dict[str, Any]],
shape: tuple[int, ...],
shape: Tuple[int, ...],
with_background: bool,
min_object_size: int = 0,
) -> np.ndarray:
Expand Down
6 changes: 3 additions & 3 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def precompute_state(
model_type: str = util._DEFAULT_MODEL,
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
key: Optional[str] = None,
ndim: Union[int] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
precompute_amg_state: bool = False,
Expand Down Expand Up @@ -158,8 +158,8 @@ def main():
parser.add_argument(
"--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
)
parser.add_argument("-n", "--ndim")
parser.add_argument("-p", "--precompute_amg_state")
parser.add_argument("-n", "--ndim", type=int)
parser.add_argument("-p", "--precompute_amg_state", action="store_true")

args = parser.parse_args()
precompute_state(
Expand Down
8 changes: 4 additions & 4 deletions micro_sam/prompt_based_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import warnings
from typing import Optional
from typing import Optional, Tuple

import numpy as np
from nifty.tools import blocking
Expand Down Expand Up @@ -308,7 +308,7 @@ def segment_from_mask(
use_box: bool = True,
use_mask: bool = True,
use_points: bool = False,
original_size: Optional[tuple[int, ...]] = None,
original_size: Optional[Tuple[int, ...]] = None,
multimask_output: bool = False,
return_all: bool = False,
return_logits: bool = False,
Expand Down Expand Up @@ -401,7 +401,7 @@ def segment_from_box(
box: np.ndarray,
image_embeddings: Optional[util.ImageEmbeddings] = None,
i: Optional[int] = None,
original_size: Optional[tuple[int, ...]] = None,
original_size: Optional[Tuple[int, ...]] = None,
multimask_output: bool = False,
return_all: bool = False,
):
Expand Down Expand Up @@ -443,7 +443,7 @@ def segment_from_box_and_points(
labels: np.ndarray,
image_embeddings: Optional[util.ImageEmbeddings] = None,
i: Optional[int] = None,
original_size: Optional[tuple[int, ...]] = None,
original_size: Optional[Tuple[int, ...]] = None,
multimask_output: bool = False,
return_all: bool = False,
):
Expand Down
31 changes: 13 additions & 18 deletions micro_sam/prompt_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
"""

from collections.abc import Mapping
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import numpy as np
from scipy.ndimage import binary_dilation

import torch
from kornia.morphology import dilation


class PointAndBoxPromptGenerator:
Expand Down Expand Up @@ -156,10 +155,10 @@ def __call__(
self,
segmentation: np.ndarray,
segmentation_id: int,
bbox_coordinates: Mapping[int, tuple],
bbox_coordinates: Mapping[int, Tuple],
center_coordinates: Optional[Mapping[int, np.ndarray]] = None
) -> tuple[
Optional[list[tuple]], Optional[list[int]], Optional[list[tuple]], np.ndarray
) -> Tuple[
Optional[List[Tuple]], Optional[List[int]], Optional[List[Tuple]], np.ndarray
]:
"""Generate the prompts for one object in the segmentation.
Expand Down Expand Up @@ -197,7 +196,7 @@ class IterativePromptGenerator:
"""
def _get_positive_points(self, pos_region, overlap_region):
positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
# we may have objects withput a positive region (= missing true foreground)
# we may have objects without a positive region (= missing true foreground)
# in this case we just sample a point where the model was already correct
positive_locations = [
torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
Expand Down Expand Up @@ -230,13 +229,13 @@ def _get_negative_points(self, negative_region_batched, true_object_batched, gt_
bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
torch.max(x_coords) + 1, torch.max(y_coords) + 1])
bbox_mask = torch.zeros_like(true_object).squeeze(0)
bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1

custom_df = 3 # custom dilation factor to perform dilation by expanding the pixels of bbox
bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, gt.shape[-2]),
max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-1])] = 1
bbox_mask = bbox_mask[None].to(device)

# NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
# TODO: just expand the pixels of bbox
dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(device)).squeeze(0)
background_mask = abs(dilated_bbox_mask - true_object)
background_mask = torch.abs(bbox_mask - true_object)
tmp_neg_loc = torch.where(background_mask)

# there is a chance that the object is small to not return a decent-sized bounding box
Expand All @@ -258,16 +257,12 @@ def __call__(
self,
gt: torch.Tensor,
object_mask: torch.Tensor,
current_points: torch.Tensor,
current_labels: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate the prompts for each object iteratively in the segmentation.
Args:
The groundtruth segmentation.
The predicted objects.
The current points.
Thr current labels.
Returns:
The updated point prompt coordinates.
Expand All @@ -278,7 +273,7 @@ def __call__(

true_object = gt.to(device)
expected_diff = (object_mask - true_object)
neg_region = (expected_diff == 1).to(torch.float)
neg_region = (expected_diff == 1).to(torch.float32)
pos_region = (expected_diff == -1)
overlap_region = torch.logical_and(object_mask == 1, true_object == 1).to(torch.float32)

Expand All @@ -290,7 +285,7 @@ def __call__(
neg_coordinates = torch.tensor(neg_coordinates)[:, None]
pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]

net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1)
net_labels = torch.cat([pos_labels, neg_labels], dim=1)

return net_coords, net_labels
94 changes: 20 additions & 74 deletions micro_sam/training/sam_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import torch
import torch_em

from kornia.morphology import dilation
from torchvision.utils import make_grid
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
"""Trainer class for training the Segment Anything model.
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
n_objects_per_batch: Optional[int] = None,
mse_loss: torch.nn.Module = torch.nn.MSELoss(),
_sigmoid: torch.nn.Module = torch.nn.Sigmoid(),
prompt_generator=IterativePromptGenerator(),
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -45,6 +47,7 @@ def __init__(
self._sigmoid = _sigmoid
self.n_objects_per_batch = n_objects_per_batch
self.n_sub_iteration = n_sub_iteration
self.prompt_generator = prompt_generator
self._kwargs = kwargs

def _get_prompt_and_multimasking_choices(self, current_iteration):
Expand Down Expand Up @@ -131,8 +134,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):

# outer loop is over the batch (different image/patch predictions)
for m_, y_, ids_, predicted_iou_ in zip(masks, y, sampled_ids, predicted_iou_values):
per_object_dice_scores = []
per_object_iou_scores = []
per_object_dice_scores, per_object_iou_scores = [], []

# inner loop is over the channels, this corresponds to the different predicted objects
for i, (predicted_obj, predicted_iou) in enumerate(zip(m_, predicted_iou_)):
Expand All @@ -157,7 +159,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):
return loss, mask_loss, iou_regression_loss, mean_model_iou

def _postprocess_outputs(self, masks):
""" masks look like -> (B, 1, X, Y)
""" "masks" look like -> (B, 1, X, Y)
where, B is the number of objects, (X, Y) is the input image shape
"""
instance_labels = []
Expand All @@ -174,8 +176,7 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
masks = [m["masks"] for m in batched_outputs]
pred_labels = self._postprocess_outputs(masks)

# we do the condition below to adapt w.r.t. the multimask output
# to select the "objectively" best response
# we do the condition below to adapt w.r.t. the multimask output to select the "objectively" best response
if pred_labels.dim() == 5:
metric = min([self.metric(pred_labels[:, :, i, :, :], sampled_binary_y.to(self.device))
for i in range(pred_labels.shape[2])])
Expand All @@ -192,10 +193,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
image_embeddings = self.model.image_embeddings_oft(input_images)

loss = 0.0
mask_loss = 0.0
iou_regression_loss = 0.0
mean_model_iou = 0.0
loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

# this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
for i in range(0, num_subiter):
Expand All @@ -219,9 +217,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
mask, l_mask = [], []
for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]):
best_iou_idx = torch.argmax(_iou)

best_mask, best_logits = _m[best_iou_idx], _l[best_iou_idx]
best_mask, best_logits = best_mask[None], best_logits[None]
best_mask, best_logits = _m[best_iou_idx][None], _l[best_iou_idx][None]
mask.append(self._sigmoid(best_mask))
l_mask.append(best_logits)

Expand All @@ -244,57 +240,13 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batched_inputs, logits_masks):
# here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
net_coords, net_labels = [], []

# here, we get each object in the pairs and do the point choices per-object
for pred_obj, true_obj in zip(x1, x2):
true_obj = true_obj.to(self.device)

expected_diff = (pred_obj - true_obj)

neg_region = (expected_diff == 1).to(torch.float32)
pos_region = (expected_diff == -1)
overlap_region = torch.logical_and(pred_obj == 1, true_obj == 1).to(torch.float32)

# POSITIVE POINTS
tmp_pos_loc = torch.where(pos_region)
if torch.stack(tmp_pos_loc).shape[-1] == 0:
tmp_pos_loc = torch.where(overlap_region)

pos_index = np.random.choice(len(tmp_pos_loc[1]))
pos_coordinates = int(tmp_pos_loc[1][pos_index]), int(tmp_pos_loc[2][pos_index])
pos_coordinates = pos_coordinates[::-1]
pos_labels = 1

# NEGATIVE POINTS
tmp_neg_loc = torch.where(neg_region)
if torch.stack(tmp_neg_loc).shape[-1] == 0:
tmp_true_loc = torch.where(true_obj)
x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
torch.max(x_coords) + 1, torch.max(y_coords) + 1])
bbox_mask = torch.zeros_like(true_obj).squeeze(0)
bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
bbox_mask = bbox_mask[None].to(self.device)

dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
background_mask = abs(dilated_bbox_mask - true_obj)
tmp_neg_loc = torch.where(background_mask)

neg_index = np.random.choice(len(tmp_neg_loc[1]))
neg_coordinates = int(tmp_neg_loc[1][neg_index]), int(tmp_neg_loc[2][neg_index])
neg_coordinates = neg_coordinates[::-1]
neg_labels = 0

net_coords.append([pos_coordinates, neg_coordinates])
net_labels.append([pos_labels, neg_labels])

if "point_labels" in _inp.keys():
updated_point_coords = torch.cat([_inp["point_coords"], torch.tensor(net_coords)], dim=1)
updated_point_labels = torch.cat([_inp["point_labels"], torch.tensor(net_labels)], dim=1)
else:
updated_point_coords = torch.tensor(net_coords)
updated_point_labels = torch.tensor(net_labels)
net_coords, net_labels = self.prompt_generator(x2, x1)

updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
if "point_coords" in _inp.keys() else net_coords
updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
if "point_labels" in _inp.keys() else net_labels

_inp["point_coords"] = updated_point_coords
_inp["point_labels"] = updated_point_labels
Expand All @@ -305,9 +257,8 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
#

def _update_samples_for_gt_instances(self, y, n_samples):
num_instances_gt = [len(torch.unique(_y)) for _y in y]
if n_samples > min(num_instances_gt):
n_samples = min(num_instances_gt) - 1
num_instances_gt = torch.amax(y, dim=(1, 2, 3))
n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
return n_samples

def _train_epoch_impl(self, progress, forward_context, backprop):
Expand All @@ -320,8 +271,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
self.optimizer.zero_grad()

with forward_context():
n_samples = self.n_objects_per_batch
n_samples = self._update_samples_for_gt_instances(y, n_samples)
n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)

n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

Expand Down Expand Up @@ -375,16 +325,12 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
def _validate_impl(self, forward_context):
self.model.eval()

metric_val = 0.0
loss_val = 0.0
model_iou_val = 0.0
val_iteration = 0
metric_val, loss_val, model_iou_val, val_iteration = 0.0, 0.0, 0.0, 0.0

with torch.no_grad():
for x, y in self.val_loader:
with forward_context():
n_samples = self.n_objects_per_batch
n_samples = self._update_samples_for_gt_instances(y, n_samples)
n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)

(n_pos, n_neg,
get_boxes, multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _project_tiled_embeddings(image_embeddings):

def project_embeddings_for_visualization(
image_embeddings: ImageEmbeddings
) -> tuple[np.ndarray, tuple[float, ...]]:
) -> Tuple[np.ndarray, Tuple[float, ...]]:
"""Project image embeddings to pixel-wise PCA.
Args:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ console_scripts =
micro_sam.annotator_3d = micro_sam.sam_annotator.annotator_3d:main
micro_sam.annotator_tracking = micro_sam.sam_annotator.annotator_tracking:main
micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main
micro_sam.precompute_embeddings = micro_sam.util:main
micro_sam.precompute_embeddings = micro_sam.precompute_state:main

# make sure it gets included in your package
[options.package_data]
Expand Down

0 comments on commit a2269b1

Please sign in to comment.