Skip to content

Commit

Permalink
Merge pull request #192 from computational-cell-analytics/train-update-aa
Browse files Browse the repository at this point in the history

Refactoring Iterative Training Scheme
  • Loading branch information
constantinpape authored Sep 20, 2023
2 parents 3025b6c + 158bc25 commit 66165b6
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 90 deletions.
5 changes: 3 additions & 2 deletions finetuning/livecell_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):

def finetune_livecell(args):
"""Example code for finetuning SAM on LiveCELL"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu
patch_shape = (520, 740) # the patch shape for training
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled

Expand Down Expand Up @@ -105,7 +106,7 @@ def main():
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be use din the annotation tools."
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
args = parser.parse_args()
finetune_livecell(args)
Expand Down
23 changes: 9 additions & 14 deletions micro_sam/prompt_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from scipy.ndimage import binary_dilation

import torch
from kornia.morphology import dilation


class PointAndBoxPromptGenerator:
Expand Down Expand Up @@ -197,7 +196,7 @@ class IterativePromptGenerator:
"""
def _get_positive_points(self, pos_region, overlap_region):
positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
# we may have objects withput a positive region (= missing true foreground)
# we may have objects without a positive region (= missing true foreground)
# in this case we just sample a point where the model was already correct
positive_locations = [
torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
Expand Down Expand Up @@ -230,13 +229,13 @@ def _get_negative_points(self, negative_region_batched, true_object_batched, gt_
bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
torch.max(x_coords) + 1, torch.max(y_coords) + 1])
bbox_mask = torch.zeros_like(true_object).squeeze(0)
bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1

custom_df = 3 # custom dilation factor to perform dilation by expanding the pixels of bbox
bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, gt.shape[-2]),
max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, gt.shape[-1])] = 1
bbox_mask = bbox_mask[None].to(device)

# NOTE: FIX: here we add dilation to the bbox because in some case we couldn't find objects at all
# TODO: just expand the pixels of bbox
dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(device)).squeeze(0)
background_mask = abs(dilated_bbox_mask - true_object)
background_mask = torch.abs(bbox_mask - true_object)
tmp_neg_loc = torch.where(background_mask)

# there is a chance that the object is too small to return a decent-sized bounding box
Expand All @@ -258,16 +257,12 @@ def __call__(
self,
gt: torch.Tensor,
object_mask: torch.Tensor,
current_points: torch.Tensor,
current_labels: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate the prompts for each object iteratively in the segmentation.
Args:
gt: The groundtruth segmentation.
object_mask: The predicted objects.
The current points.
Thr current labels.
Returns:
The updated point prompt coordinates.
Expand All @@ -278,7 +273,7 @@ def __call__(

true_object = gt.to(device)
expected_diff = (object_mask - true_object)
neg_region = (expected_diff == 1).to(torch.float)
neg_region = (expected_diff == 1).to(torch.float32)
pos_region = (expected_diff == -1)
overlap_region = torch.logical_and(object_mask == 1, true_object == 1).to(torch.float32)

Expand All @@ -290,7 +285,7 @@ def __call__(
neg_coordinates = torch.tensor(neg_coordinates)[:, None]
pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]

net_coords = torch.cat([current_points, pos_coordinates, neg_coordinates], dim=1)
net_labels = torch.cat([current_labels, pos_labels, neg_labels], dim=1)
net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1)
net_labels = torch.cat([pos_labels, neg_labels], dim=1)

return net_coords, net_labels
94 changes: 20 additions & 74 deletions micro_sam/training/sam_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import torch
import torch_em

from kornia.morphology import dilation
from torchvision.utils import make_grid
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
"""Trainer class for training the Segment Anything model.
Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
n_objects_per_batch: Optional[int] = None,
mse_loss: torch.nn.Module = torch.nn.MSELoss(),
_sigmoid: torch.nn.Module = torch.nn.Sigmoid(),
prompt_generator=IterativePromptGenerator(),
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -45,6 +47,7 @@ def __init__(
self._sigmoid = _sigmoid
self.n_objects_per_batch = n_objects_per_batch
self.n_sub_iteration = n_sub_iteration
self.prompt_generator = prompt_generator
self._kwargs = kwargs

def _get_prompt_and_multimasking_choices(self, current_iteration):
Expand Down Expand Up @@ -131,8 +134,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):

# outer loop is over the batch (different image/patch predictions)
for m_, y_, ids_, predicted_iou_ in zip(masks, y, sampled_ids, predicted_iou_values):
per_object_dice_scores = []
per_object_iou_scores = []
per_object_dice_scores, per_object_iou_scores = [], []

# inner loop is over the channels, this corresponds to the different predicted objects
for i, (predicted_obj, predicted_iou) in enumerate(zip(m_, predicted_iou_)):
Expand All @@ -157,7 +159,7 @@ def _get_net_loss(self, batched_outputs, y, sampled_ids):
return loss, mask_loss, iou_regression_loss, mean_model_iou

def _postprocess_outputs(self, masks):
""" masks look like -> (B, 1, X, Y)
""" "masks" look like -> (B, 1, X, Y)
where, B is the number of objects, (X, Y) is the input image shape
"""
instance_labels = []
Expand All @@ -174,8 +176,7 @@ def _get_val_metric(self, batched_outputs, sampled_binary_y):
masks = [m["masks"] for m in batched_outputs]
pred_labels = self._postprocess_outputs(masks)

# we do the condition below to adapt w.r.t. the multimask output
# to select the "objectively" best response
# we do the condition below to adapt w.r.t. the multimask output to select the "objectively" best response
if pred_labels.dim() == 5:
metric = min([self.metric(pred_labels[:, :, i, :, :], sampled_binary_y.to(self.device))
for i in range(pred_labels.shape[2])])
Expand All @@ -192,10 +193,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
input_images = torch.stack([self.model.preprocess(x=x["image"].to(self.device)) for x in batched_inputs], dim=0)
image_embeddings = self.model.image_embeddings_oft(input_images)

loss = 0.0
mask_loss = 0.0
iou_regression_loss = 0.0
mean_model_iou = 0.0
loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

# this loop takes care of the idea of sub-iterations, i.e. the number of times we iterate over each batch
for i in range(0, num_subiter):
Expand All @@ -219,9 +217,7 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
mask, l_mask = [], []
for _m, _l, _iou in zip(m["masks"], m["low_res_masks"], m["iou_predictions"]):
best_iou_idx = torch.argmax(_iou)

best_mask, best_logits = _m[best_iou_idx], _l[best_iou_idx]
best_mask, best_logits = best_mask[None], best_logits[None]
best_mask, best_logits = _m[best_iou_idx][None], _l[best_iou_idx][None]
mask.append(self._sigmoid(best_mask))
l_mask.append(best_logits)

Expand All @@ -244,57 +240,13 @@ def _update_masks(self, batched_inputs, y, sampled_binary_y, sampled_ids, num_su
def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batched_inputs, logits_masks):
# here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
for x1, x2, _inp, logits in zip(masks, sampled_binary_y, batched_inputs, logits_masks):
net_coords, net_labels = [], []

# here, we get each object in the pairs and do the point choices per-object
for pred_obj, true_obj in zip(x1, x2):
true_obj = true_obj.to(self.device)

expected_diff = (pred_obj - true_obj)

neg_region = (expected_diff == 1).to(torch.float32)
pos_region = (expected_diff == -1)
overlap_region = torch.logical_and(pred_obj == 1, true_obj == 1).to(torch.float32)

# POSITIVE POINTS
tmp_pos_loc = torch.where(pos_region)
if torch.stack(tmp_pos_loc).shape[-1] == 0:
tmp_pos_loc = torch.where(overlap_region)

pos_index = np.random.choice(len(tmp_pos_loc[1]))
pos_coordinates = int(tmp_pos_loc[1][pos_index]), int(tmp_pos_loc[2][pos_index])
pos_coordinates = pos_coordinates[::-1]
pos_labels = 1

# NEGATIVE POINTS
tmp_neg_loc = torch.where(neg_region)
if torch.stack(tmp_neg_loc).shape[-1] == 0:
tmp_true_loc = torch.where(true_obj)
x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2]
bbox = torch.stack([torch.min(x_coords), torch.min(y_coords),
torch.max(x_coords) + 1, torch.max(y_coords) + 1])
bbox_mask = torch.zeros_like(true_obj).squeeze(0)
bbox_mask[bbox[0]:bbox[2], bbox[1]:bbox[3]] = 1
bbox_mask = bbox_mask[None].to(self.device)

dilated_bbox_mask = dilation(bbox_mask[None], torch.ones(3, 3).to(self.device)).squeeze(0)
background_mask = abs(dilated_bbox_mask - true_obj)
tmp_neg_loc = torch.where(background_mask)

neg_index = np.random.choice(len(tmp_neg_loc[1]))
neg_coordinates = int(tmp_neg_loc[1][neg_index]), int(tmp_neg_loc[2][neg_index])
neg_coordinates = neg_coordinates[::-1]
neg_labels = 0

net_coords.append([pos_coordinates, neg_coordinates])
net_labels.append([pos_labels, neg_labels])

if "point_labels" in _inp.keys():
updated_point_coords = torch.cat([_inp["point_coords"], torch.tensor(net_coords)], dim=1)
updated_point_labels = torch.cat([_inp["point_labels"], torch.tensor(net_labels)], dim=1)
else:
updated_point_coords = torch.tensor(net_coords)
updated_point_labels = torch.tensor(net_labels)
net_coords, net_labels = self.prompt_generator(x2, x1)

updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
if "point_coords" in _inp.keys() else net_coords
updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
if "point_labels" in _inp.keys() else net_labels

_inp["point_coords"] = updated_point_coords
_inp["point_labels"] = updated_point_labels
Expand All @@ -305,9 +257,8 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
#

def _update_samples_for_gt_instances(self, y, n_samples):
num_instances_gt = [len(torch.unique(_y)) for _y in y]
if n_samples > min(num_instances_gt):
n_samples = min(num_instances_gt) - 1
num_instances_gt = torch.amax(y, dim=(1, 2, 3))
n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
return n_samples

def _train_epoch_impl(self, progress, forward_context, backprop):
Expand All @@ -320,8 +271,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
self.optimizer.zero_grad()

with forward_context():
n_samples = self.n_objects_per_batch
n_samples = self._update_samples_for_gt_instances(y, n_samples)
n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)

n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

Expand Down Expand Up @@ -375,16 +325,12 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
def _validate_impl(self, forward_context):
self.model.eval()

metric_val = 0.0
loss_val = 0.0
model_iou_val = 0.0
val_iteration = 0
metric_val, loss_val, model_iou_val, val_iteration = 0.0, 0.0, 0.0, 0.0

with torch.no_grad():
for x, y in self.val_loader:
with forward_context():
n_samples = self.n_objects_per_batch
n_samples = self._update_samples_for_gt_instances(y, n_samples)
n_samples = self._update_samples_for_gt_instances(y, self.n_objects_per_batch)

(n_pos, n_neg,
get_boxes, multimask_output) = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
Expand Down

0 comments on commit 66165b6

Please sign in to comment.