diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index 6c66ccb3..9fdf9ede 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -389,6 +389,7 @@ def default_sam_dataset(
     is_train: bool = True,
     min_size: int = 25,
     max_sampling_attempts: Optional[int] = None,
+    is_seg_dataset: Optional[bool] = None,
     **kwargs,
 ) -> Dataset:
     """Create a PyTorch Dataset for training a SAM model.
@@ -412,6 +413,8 @@
         is_train: Whether this dataset is used for training or validation.
         min_size: Minimal object size. Smaller objects will be filtered.
         max_sampling_attempts: Number of sampling attempts to make from a dataset.
+        is_seg_dataset: Whether to build the underlying dataset as a 'torch_em.data.SegmentationDataset'
+            or as a 'torch_em.data.ImageCollectionDataset'. If None, the dataset type is derived from the data.

     Returns:
         The dataset.
@@ -443,8 +446,8 @@
     # Set a minimum number of samples per epoch.
     if n_samples is None:
         loader = torch_em.default_segmentation_loader(
-            raw_paths, raw_key, label_paths, label_key,
-            batch_size=1, patch_shape=patch_shape, ndim=2
+            raw_paths, raw_key, label_paths, label_key, batch_size=1,
+            patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
         )
         n_samples = max(len(loader), 100 if is_train else 5)

@@ -454,6 +457,7 @@
         raw_transform=raw_transform, label_transform=label_transform,
         with_channels=with_channels, ndim=2, sampler=sampler, n_samples=n_samples,
+        is_seg_dataset=is_seg_dataset,
         **kwargs,
     )
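
For context, a minimal usage sketch of the new flag (not part of the diff). It assumes `default_sam_dataset` is exported from `micro_sam.training` and that `with_segmentation_decoder` is still a required argument of the existing signature; the paths, keys, and patch shape are placeholders:

```python
from micro_sam.training import default_sam_dataset

# Placeholder folders of 2D tif images with matching instance label masks.
dataset = default_sam_dataset(
    raw_paths="data/images", raw_key="*.tif",
    label_paths="data/labels", label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=False,  # assumed required by the existing signature
    # The new flag from this diff: False forces an ImageCollectionDataset,
    # True a SegmentationDataset; the default None keeps torch_em's
    # automatic detection, so existing callers are unaffected.
    is_seg_dataset=False,
)
```

Passing the flag through to both `torch_em.default_segmentation_loader` (for the `n_samples` estimate) and the dataset constructor keeps the two code paths consistent, since they would otherwise auto-detect independently.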