
Commit

Update trainers to compute loss over low_res_mask
anwai98 committed Jul 25, 2024
1 parent 3bd8265 commit c5777be
Showing 3 changed files with 57 additions and 51 deletions.
micro_sam/training/joint_sam_trainer.py (29 changes: 12 additions & 17 deletions)
@@ -4,7 +4,6 @@
 from collections import OrderedDict
 
 import torch
-from torchvision.utils import make_grid
 
 from .sam_trainer import SamTrainer
 
@@ -85,8 +84,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
 with forward_context():
 # 1. train for the interactive segmentation
-(loss, mask_loss, iou_regression_loss, model_iou,
-sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
+(loss, mask_loss, iou_regression_loss,
+model_iou) = self._interactive_train_iteration(x, labels_instances)
 
 backprop(loss)
 
@@ -100,10 +99,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
 if self.logger is not None:
 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
 self.logger.log_train(
-self._iteration, loss, lr, x, labels_instances, samples,
-mask_loss, iou_regression_loss, model_iou, unetr_loss
+self._iteration, loss, lr, x, labels_instances, mask_loss,
+iou_regression_loss, model_iou, unetr_loss
 )
 
 self._iteration += 1
@@ -133,7 +131,7 @@ def _validate_impl(self, forward_context):
 with forward_context():
 # 1. validate for the interactive segmentation
 (loss, mask_loss, iou_regression_loss, model_iou,
-sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
+metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)
 
 with forward_context():
 # 2. validate for the automatic instance segmentation
@@ -150,7 +148,7 @@ def _validate_impl(self, forward_context):
 
 if self.logger is not None:
 self.logger.log_validation(
-self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
+self._iteration, metric_val, loss_val, x, labels_instances,
 mask_loss, iou_regression_loss, model_iou_val, unetr_loss
 )
 
@@ -161,25 +159,22 @@ class JointSamLogger(TorchEmLogger):
 """@private"""
 def __init__(self, trainer, save_root, **unused_kwargs):
 super().__init__(trainer, save_root)
-self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
-os.path.join(save_root, "logs", trainer.name)
+self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
 os.makedirs(self.log_dir, exist_ok=True)
 
 self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
 self.log_image_interval = trainer.log_image_interval
 
-def add_image(self, x, y, samples, name, step):
+def add_image(self, x, y, name, step):
 selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
 
 image = normalize_im(x[selection].cpu())
 
 self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
 self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
-sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
-self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
 
 def log_train(
-self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
 ):
 self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
 self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
@@ -188,15 +183,15 @@ def log_train(
 self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
 self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
 if step % self.log_image_interval == 0:
-self.add_image(x, y, samples, "train", step)
+self.add_image(x, y, "train", step)
 
 def log_validation(
-self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
+self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou, instance_loss
 ):
 self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
 self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
 self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
 self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
 self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
 self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
-self.add_image(x, y, samples, "validation", step)
+self.add_image(x, y, "validation", step)
micro_sam/training/sam_trainer.py (54 changes: 32 additions & 22 deletions)
@@ -5,10 +5,11 @@
 from typing import Optional
 
 import numpy as np
+
 import torch
+import torch_em
+from torch.nn import functional as F
 
-from torchvision.utils import make_grid
-import torch_em
 from torch_em.trainer.logger_base import TorchEmLogger
 
 from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
@@ -126,8 +127,7 @@ def _compute_loss(self, batched_outputs, y_one_hot):
 
 # Loop over the batch.
 for batch_output, targets in zip(batched_outputs, y_one_hot):
-
-predicted_objects = torch.sigmoid(batch_output["masks"])
+predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
 # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
 # We swap the axes that go into the dice loss so that the object axis
 # corresponds to the channel axes. This ensures that the dice is computed
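For reference, a minimal sketch (not the repository's loss implementation) of what a dice term over the sigmoid of the low-res logits can look like once the object axis is moved into the channel position, as the comment above describes; the tensor shapes below are assumed:

    import torch

    def soft_dice(pred, target, eps=1e-7):
        # pred, target: (1, n_objects, 256, 256) after moving objects to the channel axis
        dims = (0, 2, 3)
        intersection = (pred * target).sum(dims)
        denominator = pred.sum(dims) + target.sum(dims)
        return 1.0 - (2.0 * intersection / denominator.clamp(min=eps)).mean()

    low_res_logits = torch.randn(4, 1, 256, 256)                     # 4 sampled objects, 1 mask each
    low_res_targets = torch.randint(0, 2, (4, 1, 256, 256)).float()  # made-up low-res targets
    predicted_objects = torch.sigmoid(low_res_logits)
    # swap the object and channel axes so the dice is computed per object
    loss = soft_dice(predicted_objects.swapaxes(0, 1), low_res_targets.swapaxes(0, 1))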
@@ -276,6 +276,22 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
 # number of objects across the batch.
 n_objects = min(len(ids) for ids in sampled_ids)
 
+original_instance_ids = list(torch.unique(y))
+# Convert the labels to "low_res_mask" shape
+# First step is to use the logic from `ResizeLongestSide` to resize the longest side.
+target_length = self.model.transform.target_length
+target_shape = self.model.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
+y = F.interpolate(input=y, size=target_shape)
+# Next, we pad the remaining region to (1024, 1024)
+h, w = y.shape[-2:]
+padh = self.model.sam.image_encoder.img_size - h
+padw = self.model.sam.image_encoder.img_size - w
+y = F.pad(input=y, pad=(0, padw, 0, padh))
+# Finally, let's resize the labels to the desired shape (i.e. (256, 256))
+y = F.interpolate(input=y, size=(256, 256))
+
+assert list(torch.unique(y)) == original_instance_ids
+
 y = y.to(self.device)
 # Compute the one hot targets for the seg-id.
 y_one_hot = torch.stack([
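For illustration, a self-contained sketch of the label downscaling added in the hunk above, written outside the trainer. It assumes SAM's default sizes (a 1024-pixel square encoder input and 256 x 256 low-res masks) and a made-up label batch; nearest-neighbour interpolation (the F.interpolate default) keeps the instance ids intact, which is what the assert above checks:

    import torch
    from torch.nn import functional as F
    from segment_anything.utils.transforms import ResizeLongestSide

    img_size = 1024                                   # SAM image encoder input size (assumed default)
    transform = ResizeLongestSide(img_size)

    y = torch.zeros(2, 1, 768, 1280)                  # made-up batch of instance label images
    y[:, :, 100:300, 200:500] = 1.0                   # a single fake instance

    # 1. resize the longest side to 1024, mirroring the image preprocessing
    target_shape = transform.get_preprocess_shape(y.shape[2], y.shape[3], img_size)
    y_low = F.interpolate(input=y, size=target_shape)
    # 2. pad bottom/right to the square encoder input (1024 x 1024)
    h, w = y_low.shape[-2:]
    y_low = F.pad(input=y_low, pad=(0, img_size - w, 0, img_size - h))
    # 3. resize to the low-res mask resolution of the mask decoder (256 x 256)
    y_low = F.interpolate(input=y_low, size=(256, 256))
    print(y_low.shape, torch.unique(y_low))           # torch.Size([2, 1, 256, 256]) tensor([0., 1.])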
@@ -300,7 +316,7 @@ def _interactive_train_iteration(self, x, y):
 batched_inputs, y_one_hot,
 num_subiter=self.n_sub_iteration, multimask_output=multimask_output
 )
-return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
+return loss, mask_loss, iou_regression_loss, model_iou
 
 def _check_input_normalization(self, x, input_check_done):
 # The expected data range of the SAM model is 8bit (0-255).
@@ -335,16 +351,13 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 self.optimizer.zero_grad()
 
 with forward_context():
-(loss, mask_loss, iou_regression_loss, model_iou,
-sampled_binary_y) = self._interactive_train_iteration(x, y)
+(loss, mask_loss, iou_regression_loss, model_iou) = self._interactive_train_iteration(x, y)
 
 backprop(loss)
 
 if self.logger is not None:
 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
-self.logger.log_train(self._iteration, loss, lr, x, y, samples,
-mask_loss, iou_regression_loss, model_iou)
+self.logger.log_train(self._iteration, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou)
 
 self._iteration += 1
 n_iter += 1
@@ -374,7 +387,7 @@ def _interactive_val_iteration(self, x, y, val_iteration):
 metric = mask_loss
 model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
 
-return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
+return loss, mask_loss, iou_regression_loss, model_iou, metric
 
 def _validate_impl(self, forward_context):
 self.model.eval()
@@ -389,8 +402,8 @@ def _validate_impl(self, forward_context):
 input_check_done = self._check_input_normalization(x, input_check_done)
 
 with forward_context():
-(loss, mask_loss, iou_regression_loss, model_iou,
-sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
+(loss, mask_loss, iou_regression_loss,
+model_iou, metric) = self._interactive_val_iteration(x, y, val_iteration)
 
 loss_val += loss.item()
 metric_val += metric.item()
@@ -405,8 +418,7 @@ def _validate_impl(self, forward_context):
 
 if self.logger is not None:
 self.logger.log_validation(
-self._iteration, metric_val, loss_val, x, y,
-sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
+self._iteration, metric_val, loss_val, x, y, mask_loss, iou_regression_loss, model_iou_val
 )
 
 return metric_val
@@ -423,25 +435,23 @@ def __init__(self, trainer, save_root, **unused_kwargs):
 self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
 self.log_image_interval = trainer.log_image_interval
 
-def add_image(self, x, y, samples, name, step):
+def add_image(self, x, y, name, step):
 self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
 self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
-sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
-self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
 
-def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
+def log_train(self, step, loss, lr, x, y, mask_loss, iou_regression_loss, model_iou):
 self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
 self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
 self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
 self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
 self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
 if step % self.log_image_interval == 0:
-self.add_image(x, y, samples, "train", step)
+self.add_image(x, y, "train", step)
 
-def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
+def log_validation(self, step, metric, loss, x, y, mask_loss, iou_regression_loss, model_iou):
 self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
 self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
 self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
 self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
 self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
-self.add_image(x, y, samples, "validation", step)
+self.add_image(x, y, "validation", step)
micro_sam/training/trainable_sam.py (25 changes: 13 additions & 12 deletions)
@@ -15,15 +15,18 @@ class TrainableSAM(nn.Module):
 Args:
 sam: The SegmentAnything Model.
 device: The device for training.
+upsampled_masks: Whether to return the output masks in the original input shape.
 """
 def __init__(
 self,
 sam: Sam,
 device: Union[str, torch.device],
+upsampled_masks: bool = True,
 ) -> None:
 super().__init__()
 self.sam = sam
 self.device = device
+self.upsampled_masks = upsampled_masks
 self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 
 def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
@@ -111,18 +114,16 @@ def forward(
 multimask_output=multimask_output,
 )
 
-masks = self.sam.postprocess_masks(
-low_res_masks,
-input_size=image_record["input_size"],
-original_size=image_record["original_size"],
-)
+curr_outputs = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
 
-outputs.append(
-{
-"low_res_masks": low_res_masks,
-"masks": masks,
-"iou_predictions": iou_predictions
-}
-)
+if self.upsampled_masks:
+masks = self.sam.postprocess_masks(
+low_res_masks,
+input_size=image_record["input_size"],
+original_size=image_record["original_size"],
+)
+curr_outputs["masks"] = masks
+
+outputs.append(curr_outputs)
 
 return outputs
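A hedged usage sketch of the new flag (the model construction below is assumed and not part of this commit): with upsampled_masks=False the forward pass only returns "low_res_masks" and "iou_predictions", so a trainer that computes the loss on the low-res masks can skip the postprocessing to the original image shape entirely.

    from segment_anything import sam_model_registry
    from micro_sam.training.trainable_sam import TrainableSAM

    sam = sam_model_registry["vit_b"]()              # randomly initialised ViT-B SAM (assumed setup)
    model = TrainableSAM(sam, device="cpu", upsampled_masks=False)

    # With batched_inputs prepared as in SamTrainer (not shown here):
    # outputs = model(batched_inputs, multimask_output=False)
    # outputs[0]["low_res_masks"]   -> (num_objects, 1, 256, 256) logits used for the loss
    # "masks" is only present in the output dicts when upsampled_masks=True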
