378 add ensuring patch shape for raw dataset (#379)
added ensure_patch_shape function in raw dataset
lufre1 authored Oct 15, 2024
1 parent d42d50c commit 461ab5a
Showing 2 changed files with 16 additions and 7 deletions.
torch_em/data/raw_dataset.py: 10 changes (8 additions, 2 deletions)
@@ -7,7 +7,7 @@

from elf.wrapper import RoiWrapper

-from ..util import ensure_tensor_with_channels, load_data
+from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data


class RawDataset(torch.utils.data.Dataset):
@@ -107,7 +107,13 @@ def _get_sample(self, index):
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
-
+        if self.patch_shape is not None:
+            raw = ensure_patch_shape(
+                raw=raw,
+                labels=None,
+                patch_shape=self.patch_shape,
+                have_raw_channels=self._with_channels
+            )
        # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim
        if len(self.patch_shape) == self._ndim + 1:
            raw = raw.squeeze(1 if self._with_channels else 0)
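
The new block guarantees that a sampled raw crop is padded up to the requested patch shape before the singleton spatial axis is squeezed, e.g. for crops taken at the border of the volume. A minimal sketch of that padding behavior, using plain numpy and a hypothetical 3d patch shape (the pad-width computation mirrors the one in ensure_patch_shape):

import numpy as np

patch_shape = (32, 128, 128)        # hypothetical 3d patch shape
raw = np.random.rand(32, 128, 100)  # crop that came out too small in the last axis

# pad nothing on the left and the missing extent on the right of each axis
pad_width = [(0, max(0, psh - sh)) for sh, psh in zip(raw.shape, patch_shape)]
raw = np.pad(raw, pad_width=pad_width)
assert raw.shape == patch_shape
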
torch_em/util/util.py: 13 changes (8 additions, 5 deletions)
@@ -149,14 +149,15 @@ def ensure_patch_shape(
    raw, labels, patch_shape, have_raw_channels=False, have_label_channels=False, channel_first=True
):
    raw_shape = raw.shape
-    labels_shape = labels.shape
+    if labels is not None:
+        labels_shape = labels.shape

    # In case the inputs has channels and they are channels first
    # IMPORTANT: for ImageCollectionDataset
    if have_raw_channels and channel_first:
        raw_shape = raw_shape[1:]

-    if have_label_channels and channel_first:
+    if have_label_channels and channel_first and labels is not None:
        labels_shape = labels_shape[1:]


@@ -173,7 +174,7 @@
        raw = np.pad(array=raw, pad_width=pad_width)

    # Extract the pad width and pad the label inputs
-    if any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
+    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]

        if have_label_channels and channel_first:
@@ -184,8 +185,10 @@
            pad_width = pw

        labels = np.pad(array=labels, pad_width=pad_width)
-
-    return raw, labels
+    if labels is None:
+        return raw
+    else:
+        return raw, labels


def get_constructor_arguments(obj):
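
With the changed return, callers that pass labels still get back the (raw, labels) pair, while labels=None, as in RawDataset above, yields only the padded raw. A usage sketch, assuming ensure_patch_shape is re-exported from torch_em.util as the dataset import above suggests:

import numpy as np
from torch_em.util import ensure_patch_shape

raw = np.random.rand(32, 128, 100)           # hypothetical under-sized crop
labels = np.random.randint(0, 2, raw.shape)  # matching label crop

# with labels: both arrays are padded and returned as a pair
raw_padded, labels_padded = ensure_patch_shape(raw=raw, labels=labels, patch_shape=(32, 128, 128))

# without labels (the new RawDataset code path): only the padded raw is returned
raw_only = ensure_patch_shape(raw=raw, labels=None, patch_shape=(32, 128, 128))
assert raw_only.shape == (32, 128, 128)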
