Skip to content

Commit

Permalink
Update preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 25, 2024
1 parent c5777be commit fa4017e
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions micro_sam/training/sam_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,34 @@ def _compute_iou(self, pred, true, eps=1e-7):
iou = overlap / (union + eps)
return iou

def preprocess_one_hot_masks(self, y_one_hot, low_res_shape=(256, 256)):
    """Bring one-hot target masks to the resolution of the low-res mask predictions.

    The targets go through three steps:

    1. Resize to the shape returned by
       `self.model.transform.get_preprocess_shape` for the transform's
       `target_length`.
    2. Zero-pad on the bottom and right up to the image encoder's square
       input size (`self.model.sam.image_encoder.img_size`).
    3. Resize to `low_res_shape`.

    Both resizing steps use nearest-neighbour interpolation, so no new
    values are introduced into the masks.

    Args:
        y_one_hot: The one-hot encoded targets. The last two of its four
            axes are treated as height and width
            (assumed layout: batch x objects x H x W -- TODO confirm with caller).
        low_res_shape: Spatial shape (height, width) of the returned masks.
            Defaults to (256, 256), the previously hard-coded value.

    Returns:
        The targets with the leading two axes unchanged and the spatial
        shape equal to `low_res_shape`.
    """
    transform = self.model.transform
    encoder_size = self.model.sam.image_encoder.img_size

    # Step 1: resize with the same shape logic that the model's transform exposes.
    resized_shape = transform.get_preprocess_shape(
        y_one_hot.shape[2], y_one_hot.shape[3], transform.target_length
    )
    y_one_hot = F.interpolate(input=y_one_hot, size=resized_shape, mode="nearest")

    # Step 2: pad the bottom / right borders up to the encoder's input size.
    height, width = y_one_hot.shape[-2:]
    y_one_hot = F.pad(input=y_one_hot, pad=(0, encoder_size - width, 0, encoder_size - height))

    # Step 3: resize to the requested low-res shape.
    return F.interpolate(input=y_one_hot, size=tuple(low_res_shape), mode="nearest")

def _compute_loss(self, batched_outputs, y_one_hot):
"""Compute the loss for one iteration. The loss is made up of two components:
- The mask loss: dice score between the predicted masks and targets.
- The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
"""
mask_loss, iou_regression_loss = 0.0, 0.0

# TODO
y_one_hot = self.preprocess_one_hot_masks(y_one_hot)

# Loop over the batch.
for batch_output, targets in zip(batched_outputs, y_one_hot):
predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
Expand Down Expand Up @@ -276,22 +297,6 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
# number of objects across the batch.
n_objects = min(len(ids) for ids in sampled_ids)

original_instance_ids = list(torch.unique(y))
# Convert the labels to "low_res_mask" shape
# First step is to use the logic from `ResizeLongestSide` to resize the longest side.
target_length = self.model.transform.target_length
target_shape = self.model.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
y = F.interpolate(input=y, size=target_shape)
# Next, we pad the remaining region to (1024, 1024)
h, w = y.shape[-2:]
padh = self.model.sam.image_encoder.img_size - h
padw = self.model.sam.image_encoder.img_size - w
y = F.pad(input=y, pad=(0, padw, 0, padh))
# Finally, let's resize the labels to the desired shape (i.e. (256, 256))
y = F.interpolate(input=y, size=(256, 256))

assert list(torch.unique(y)) == original_instance_ids

y = y.to(self.device)
# Compute the one hot targets for the seg-id.
y_one_hot = torch.stack([
Expand Down

0 comments on commit fa4017e

Please sign in to comment.