From 461ab5ae4dfe00dd965866d7233866c0594914e3 Mon Sep 17 00:00:00 2001
From: lufre1 <155526548+lufre1@users.noreply.github.com>
Date: Tue, 15 Oct 2024 14:14:55 +0200
Subject: [PATCH] 378 add ensuring patch shape for raw dataset (#379)

added ensure patch shape function in raw dataset
---
 torch_em/data/raw_dataset.py | 10 ++++++++--
 torch_em/util/util.py        | 13 ++++++++-----
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py
index ec410718..402659d1 100644
--- a/torch_em/data/raw_dataset.py
+++ b/torch_em/data/raw_dataset.py
@@ -7,7 +7,7 @@
 
 from elf.wrapper import RoiWrapper
 
-from ..util import ensure_tensor_with_channels, load_data
+from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data
 
 
 class RawDataset(torch.utils.data.Dataset):
@@ -107,7 +107,13 @@ def _get_sample(self, index):
                 sample_id += 1
                 if sample_id > self.max_sampling_attempts:
                     raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
-
+        if self.patch_shape is not None:
+            raw = ensure_patch_shape(
+                raw=raw,
+                labels=None,
+                patch_shape=self.patch_shape,
+                have_raw_channels=self._with_channels
+            )
         # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
         if len(self.patch_shape) == self._ndim + 1:
             raw = raw.squeeze(1 if self._with_channels else 0)
diff --git a/torch_em/util/util.py b/torch_em/util/util.py
index 7f9a777f..e4062e56 100644
--- a/torch_em/util/util.py
+++ b/torch_em/util/util.py
@@ -149,14 +149,15 @@ def ensure_patch_shape(
     raw, labels, patch_shape, have_raw_channels=False, have_label_channels=False, channel_first=True
 ):
     raw_shape = raw.shape
-    labels_shape = labels.shape
+    if labels is not None:
+        labels_shape = labels.shape
 
     # In case the inputs has channels and they are channels first
     # IMPORTANT: for ImageCollectionDataset
     if have_raw_channels and channel_first:
         raw_shape = raw_shape[1:]
 
-    if have_label_channels and channel_first:
+    if have_label_channels and channel_first and labels is not None:
         labels_shape = labels_shape[1:]
 
     # Extract the pad_width and pad the raw inputs
@@ -173,7 +174,7 @@ def ensure_patch_shape(
     raw = np.pad(array=raw, pad_width=pad_width)
 
     # Extract the pad width and pad the label inputs
-    if any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
+    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
         pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]
 
         if have_label_channels and channel_first:
@@ -184,8 +185,10 @@ def ensure_patch_shape(
             pad_width = pw
 
         labels = np.pad(array=labels, pad_width=pad_width)
-
-    return raw, labels
+    if labels is None:
+        return raw
+    else:
+        return raw, labels
 
 
 def get_constructor_arguments(obj):