Skip to content

Commit

Permalink
Merge branch 'dev' into gdsdataset
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli authored Jul 27, 2023
2 parents d3c240e + 9e2d381 commit 94448b0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 45 deletions.
96 changes: 51 additions & 45 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
:py:func:`MetaTensor._copy_meta`).
"""
out = []
metas = None
metas = None # optional output metadicts for each of the return value in `rets`
is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch"))
for idx, ret in enumerate(rets):
# if not `MetaTensor`, nothing to do.
Expand All @@ -219,55 +219,61 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
# the following is not implemented but the network arch may run into this case:
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")

# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
if is_batch:
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
batch_idx = args[1]
if isinstance(batch_idx, Sequence):
batch_idx = batch_idx[0]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
try:
ret_meta = list_data_collate(ret_meta)
except (TypeError, ValueError, RuntimeError, IndexError) as e:
raise ValueError(
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
"please convert it into a torch Tensor using `x.as_tensor()` or "
"a numpy array using `x.array`."
) from e
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
ret_meta.is_batch = False
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th
# dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
if metas is None:
metas = decollate_batch(args[0], detach=False)
ret.__dict__ = metas[idx].__dict__.copy()
ret.is_batch = False

ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
out.append(ret)
# if the input was a tuple, then return it as a tuple
return tuple(out) if isinstance(rets, tuple) else out

@classmethod
def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
"""utility function to handle batched MetaTensors."""
# If we have a batch of data, then we need to be careful if a slice of
# the data is returned. Depending on how the data are indexed, we return
# some or all of the metadata, and the return object may or may not be a
# batch of data (e.g., `batch[:,-1]` versus `batch[0]`).
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
if idx > 0 or len(args) < 2 or len(args[0]) < 1:
return ret
batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor):
return ret
dec_batch = decollate_batch(args[0], detach=False)
ret_meta = dec_batch[batch_idx]
if isinstance(ret_meta, list) and ret_meta: # e.g. batch[0:2], re-collate
try:
ret_meta = list_data_collate(ret_meta)
except (TypeError, ValueError, RuntimeError, IndexError) as e:
raise ValueError(
"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
"please consider converting it into a torch Tensor using `x.as_tensor()` or "
"a numpy array using `x.array`."
) from e
elif isinstance(ret_meta, MetaObj): # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
ret_meta.is_batch = False
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th dimension.
elif func == torch.Tensor.unbind:
if len(args) > 1:
dim = args[1]
elif "dim" in kwargs:
dim = kwargs["dim"]
else:
dim = 0
if dim == 0:
if metas is None:
metas = decollate_batch(args[0], detach=False)
if hasattr(metas[idx], "__dict__"):
ret.__dict__ = metas[idx].__dict__.copy()
ret.is_batch = False
return ret

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
"""Wraps all torch functions."""
Expand Down
4 changes: 4 additions & 0 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ def test_slicing(self):
x.is_batch = True
with self.assertRaises(ValueError):
x[slice(0, 8)]
x = MetaTensor(np.zeros((3, 3, 4)))
x.is_batch = True
self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4))
self.assertEqual(x[[True, False, True]].shape, (2, 3, 4))

@parameterized.expand(DTYPES)
@SkipIfBeforePyTorchVersion((1, 8))
Expand Down

0 comments on commit 94448b0

Please sign in to comment.