diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 56024af1..52096462 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -2,7 +2,6 @@ from typing import List, Optional, Union import numpy as np -import torch from ..prompt_generators import PointAndBoxPromptGenerator from ..util import get_centers_and_bounding_boxes, get_sam_model, segmentation_to_one_hot, _get_device @@ -86,8 +85,8 @@ def _get_prompt_lists(self, gt, n_samples, prompt_generator): sampled_cell_ids = np.random.choice(cell_ids, size=min(n_samples, len(cell_ids)), replace=False) sampled_cell_ids = np.sort(sampled_cell_ids) - # only keep the bounding boxes for sampled cell ids - bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids] + # only keep the bounding boxes for sampled cell ids + bbox_coordinates = [bbox_coordinates[sampled_id] for sampled_id in sampled_cell_ids] # convert the gt to the one-hot-encoded masks for the sampled cell ids object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)