diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index c95d4227..ca53f496 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -118,6 +118,24 @@ def _compute_iou(self, pred, true, eps=1e-7):
         iou = overlap / (union + eps)
         return iou
 
+    def preprocess_one_hot_masks(self, y_one_hot):
+        """Resize the one-hot encoded masks to the (256, 256) "low_res_masks" shape.
+        """
+        # Convert the labels to "low_res_mask" shape
+        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
+        target_length = self.model.transform.target_length
+        target_shape = self.model.transform.get_preprocess_shape(y_one_hot.shape[2], y_one_hot.shape[3], target_length)
+        y_one_hot = F.interpolate(input=y_one_hot, size=target_shape)
+        # Next, we pad the remaining region to (1024, 1024)
+        h, w = y_one_hot.shape[-2:]
+        padh = self.model.sam.image_encoder.img_size - h
+        padw = self.model.sam.image_encoder.img_size - w
+        y_one_hot = F.pad(input=y_one_hot, pad=(0, padw, 0, padh))
+        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
+        y_one_hot = F.interpolate(input=y_one_hot, size=(256, 256))
+
+        return y_one_hot
+
     def _compute_loss(self, batched_outputs, y_one_hot):
         """Compute the loss for one iteration. The loss is made up of two components:
         - The mask loss: dice score between the predicted masks and targets.
@@ -125,6 +143,9 @@ def _compute_loss(self, batched_outputs, y_one_hot):
         """
         mask_loss, iou_regression_loss = 0.0, 0.0
 
+        # TODO
+        y_one_hot = self.preprocess_one_hot_masks(y_one_hot)
+
         # Loop over the batch.
         for batch_output, targets in zip(batched_outputs, y_one_hot):
             predicted_objects = torch.sigmoid(batch_output["low_res_masks"])
@@ -276,22 +297,6 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids):
         # number of objects across the batch.
         n_objects = min(len(ids) for ids in sampled_ids)
 
-        original_instance_ids = list(torch.unique(y))
-        # Convert the labels to "low_res_mask" shape
-        # First step is to use the logic from `ResizeLongestSide` to resize the longest side.
-        target_length = self.model.transform.target_length
-        target_shape = self.model.transform.get_preprocess_shape(y.shape[2], y.shape[3], target_length)
-        y = F.interpolate(input=y, size=target_shape)
-        # Next, we pad the remaining region to (1024, 1024)
-        h, w = y.shape[-2:]
-        padh = self.model.sam.image_encoder.img_size - h
-        padw = self.model.sam.image_encoder.img_size - w
-        y = F.pad(input=y, pad=(0, padw, 0, padh))
-        # Finally, let's resize the labels to the desired shape (i.e. (256, 256))
-        y = F.interpolate(input=y, size=(256, 256))
-
-        assert list(torch.unique(y)) == original_instance_ids
-        y = y.to(self.device)
 
         # Compute the one hot targets for the seg-id.
         y_one_hot = torch.stack([