Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invertable affine augmentations #10

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/augmentation/check_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import h5py
import kornia

from torch_em.transform.augmentation import KorniaAugmentationPipeline, get_augmentations
from torch_em.transform.augmentation import AugmentationPipeline, get_augmentations
from torch_em.transform.augmentation import RandomElasticDeformation

pr = '/g/schwab/hennies/project_segmentation_paper/ds_sbem-6dpf-1-whole/seg_210122_mito/seg_10nm/gt_cubes/gt000/raw_256.h5'
Expand All @@ -21,7 +21,7 @@ def check_kornia_augmentation():
degrees=90., p=1.
)

trafo = KorniaAugmentationPipeline(
trafo = AugmentationPipeline(
rot
)

Expand Down
6 changes: 4 additions & 2 deletions torch_em/data/segmentation_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import torch
import numpy as np
import torch
from elf.io import open_file
from elf.wrapper import RoiWrapper
from torch.utils.data import Dataset

from ..util import ensure_tensor_with_channels


class SegmentationDataset(torch.utils.data.Dataset):
class SegmentationDataset(Dataset):
"""
"""

max_sampling_attempts = 500

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions torch_em/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _get_paths(rpath, rkey, lpath, lkey):
return ds


def _get_default_transform(path, key, is_seg_dataset):
def _get_default_transform(path, key, is_seg_dataset, return_transforms: bool = False):
if is_seg_dataset:
with open_file(path, mode='r') as f:
shape = f[key].shape
Expand All @@ -173,7 +173,7 @@ def _get_default_transform(path, key, is_seg_dataset):
ndim = 'anisotropic' if shape[0] < shape[1] // 2 else 3
else:
ndim = 2
return get_augmentations(ndim)
return get_augmentations(ndim, return_transforms=return_transforms)


def default_segmentation_loader(
Expand Down
159 changes: 94 additions & 65 deletions torch_em/transform/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch
from typing import List, Optional, Sequence, Tuple, Union

import kornia.augmentation.utils
import numpy as np
import kornia
FynnBe marked this conversation as resolved.
Show resolved Hide resolved
import torch
from kornia import warp_affine3d
from kornia.augmentation import AugmentationBase2D, AugmentationBase3D
from kornia.augmentation.base import _AugmentationBase as Augmentation
from skimage.transform import resize

from ..util import ensure_tensor
Expand All @@ -9,18 +14,18 @@
# TODO RandomElastic3D ?


class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
FynnBe marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self,
control_point_spacing=1,
sigma=(4., 4.),
alpha=(32., 32.),
resample=kornia.constants.Resample.BILINEAR,
p=0.5,
keepdim=False,
same_on_batch=False):
super().__init__(p=p, # keepdim=keepdim,
same_on_batch=same_on_batch,
return_transform=False)
class RandomElasticDeformation(AugmentationBase2D):
def __init__(
self,
control_point_spacing: Union[int, Sequence[int]] = 1,
sigma=(4.0, 4.0),
alpha=(32.0, 32.0),
resample=kornia.constants.Resample.BILINEAR,
p=0.5,
keepdim=False,
same_on_batch=False,
):
super().__init__(p=p, same_on_batch=same_on_batch, return_transform=False) # keepdim=keepdim,
if isinstance(control_point_spacing, int):
self.control_point_spacing = [control_point_spacing] * 2
else:
Expand Down Expand Up @@ -63,13 +68,18 @@ def apply_transform(self, input, params):

# TODO implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away
class KorniaAugmentationPipeline(torch.nn.Module):
interpolatable_torch_tpyes = [torch.float16, torch.float32, torch.float64]
interpolatable_numpy_types = [np.dtype('float32'), np.dtype('float64')]
class AugmentationPipeline(torch.nn.Module):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "KorniaAugmentationPipeline" makes more sense as a name, because it's not for generic Augmentations.

interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]

def __init__(self, *kornia_augmentations, dtype=torch.float32):
def __init__(self, *augmentations: Augmentation, return_transform: bool = False, dtype=torch.float32):
super().__init__()
self.augmentations = torch.nn.ModuleList(kornia_augmentations)
self.return_transform = return_transform
self.is3D = any(isinstance(aug, AugmentationBase3D) for aug in augmentations)
for aug in augmentations:
aug.return_transform = return_transform

self.augmentations: Sequence[Augmentation] = torch.nn.ModuleList(augmentations) # type: ignore
self.dtype = dtype
self.halo = self.compute_halo()

Expand All @@ -90,32 +100,72 @@ def is_interpolatable(self, tensor):
else:
return tensor.dtype in self.interpolatable_numpy_types

def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
interpolating = 'interpolation' in getattr(augmentation, 'flags', [])
def _configure_augmentation(self, augmentation: Augmentation, interpolatable):
interpolating = "interpolation" in getattr(augmentation, "flags", [])
if interpolating:
resampler = kornia.constants.Resample.get('BILINEAR' if interpolatable else 'NEAREST')
augmentation.flags['interpolation'] = torch.tensor(resampler.value)
resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
augmentation.flags["interpolation"] = torch.tensor(resampler.value)

transformed = augmentation(tensor, params)
return transformed, augmentation._params
def _get_eye(self, tensor: torch.Tensor):
return kornia.eye_like(4 if self.is3D else 3, tensor)

def forward(self, *tensors):
def ensure_batch_tensor(self, tensor: torch.Tensor):
ensure_batch = (
kornia.augmentation.utils.helpers._transform_input3d
if self.is3D
else kornia.augmentation.utils.helpers._transform_input
)
return ensure_batch(ensure_tensor(tensor, self.dtype))

def forward(
self, *tensors, trans_matrices: Optional[Sequence[torch.Tensor]] = None
) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor]]]:
interpolatable = [self.is_interpolatable(tensor) for tensor in tensors]
tensors = [ensure_tensor(tensor, self.dtype) for tensor in tensors]
for aug in self.augmentations:
tensors = [self.ensure_batch_tensor(tensor) for tensor in tensors]

if trans_matrices is None:
trans_matrices = [self._get_eye(t) for t in tensors]
else:
trans_matrices = list(trans_matrices)

t0, params = self.transform_tensor(aug, tensors[0], interpolatable[0])
transformed_tensors = [t0]
for tensor, interpolate in zip(tensors[1:], interpolatable[1:]):
tensor, _ = self.transform_tensor(aug, tensor, interpolate, params=params)
transformed_tensors.append(tensor)
assert len(trans_matrices) == len(tensors)
for ti, tensor in enumerate(tensors):
if trans_matrices[ti] is None:
trans_matrices[ti] = self._get_eye(tensor)

tensors = transformed_tensors
return tensors
transformed_tensors = []
all_params = [None] * len(self.augmentations)
for ti, (tensor, interpolate) in enumerate(zip(tensors, interpolatable)):
for ai, aug in enumerate(self.augmentations):
self._configure_augmentation(aug, interpolatable)
tensor, trans_matrices[ti] = aug((tensor, trans_matrices[ti]), all_params[ai])

if all_params[ai] is None:
all_params[ai] = aug._params
else:
assert all_params[ai] == aug._params

transformed_tensors.append(tensor)

return (transformed_tensors, trans_matrices) if self.return_transform else transformed_tensors

def halo(self, shape):
return self.halo

def apply_inverse(self, *tensors: torch.Tensor, forward_transforms: Sequence[torch.Tensor], padding_mode="border"):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works for affine trafos / anything that can be expressed as affine trafo. But that's certainly not the case for all augmentations we have, e.g. elastic deformations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed it to apply_inverse_affine to make it explicit that it applies only to transformations that can be expressed as an affine transformation.

assert len(tensors) == len(forward_transforms)
trans_matrices = torch.linalg.inv(torch.stack(list(forward_transforms)))
return [
warp_affine3d(
src=t,
M=m,
dsize=t.shape[-3:],
flags="bliinear" if self.is_interpolatable(t) else "nearest",
padding_mode=padding_mode,
)
for t, m in zip(tensors, trans_matrices)
]


# TODO elastic deformation
# Try out:
Expand All @@ -133,25 +183,12 @@ def halo(self, shape):
}


DEFAULT_2D_AUGMENTATIONS = [
"RandomHorizontalFlip",
"RandomVerticalFlip"
]
DEFAULT_3D_AUGMENTATIONS = [
"RandomHorizontalFlip3D",
"RandomVerticalFlip3D",
"RandomDepthicalFlip3D",
]
DEFAULT_ANISOTROPIC_AUGMENTATIONS = [
"RandomHorizontalFlip3D",
"RandomVerticalFlip3D",
"RandomDepthicalFlip3D",
]


def get_augmentations(ndim=2,
transforms=None,
dtype=torch.float32):
DEFAULT_2D_AUGMENTATIONS = ["RandomHorizontalFlip", "RandomVerticalFlip"]
DEFAULT_3D_AUGMENTATIONS = ["RandomHorizontalFlip3D", "RandomVerticalFlip3D", "RandomDepthicalFlip3D"]
DEFAULT_ANISOTROPIC_AUGMENTATIONS = ["RandomHorizontalFlip3D", "RandomVerticalFlip3D", "RandomDepthicalFlip3D"]


def get_augmentations(ndim=2, transforms=None, dtype=torch.float32, return_transforms: bool = False):
if transforms is None:
assert ndim in (2, 3, "anisotropic"), f"Expect ndim to be one of (2, 3, 'anisotropic'), got {ndim}"
if ndim == 2:
Expand All @@ -160,15 +197,7 @@ def get_augmentations(ndim=2,
transforms = DEFAULT_3D_AUGMENTATIONS
else:
transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
transforms = [
getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo])
for trafo in transforms
]
transforms = [getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo]) for trafo in transforms]

assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase)
for trafo in transforms)
augmentations = KorniaAugmentationPipeline(
*transforms,
dtype=dtype
)
return augmentations
assert all(isinstance(trafo, Augmentation) for trafo in transforms)
return AugmentationPipeline(*transforms, dtype=dtype, return_transform=return_transforms)
12 changes: 8 additions & 4 deletions torch_em/util/modelzoo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import json
import os
import pathlib
import subprocess
from shutil import copyfile
from warnings import warn
Expand Down Expand Up @@ -192,9 +193,12 @@ def _default_authors():

def _default_repo():
try:
call_res = subprocess.run(['git', 'remote', '-v'], capture_output=True)
repo = call_res.stdout.decode('utf8').split('\n')[0].split()[1]
repo = repo if repo else None
call_res = subprocess.run(["git", "remote", "-v"], capture_output=True)
repo = call_res.stdout.decode("utf8").split("\n")[0].split()[1]
if repo:
repo = repo.replace("[email protected]:", "https://github.com/")
else:
repo = None
except Exception:
repo = None
return repo
Expand Down Expand Up @@ -679,7 +683,7 @@ def import_bioimageio_model(spec_path, return_spec=False):
# to the source for the bioimageio package
if spec is None:
raise RuntimeError("Need bioimageio package")
bio_spec = spec.load_and_resolve_spec(os.path.abspath(spec_path))
bio_spec = spec.load_and_resolve_spec(pathlib.Path(spec_path).absolute())

model = _load_model(bio_spec)
normalizer = _load_normalizer(bio_spec)
Expand Down
11 changes: 9 additions & 2 deletions torch_em/util/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@

def create_segmentation_test_data(data_path, raw_key, label_key, shape, chunks):
with h5py.File(data_path, 'a') as f:
f.create_dataset(raw_key, data=np.random.rand(*shape), chunks=chunks)
f.create_dataset(label_key, data=np.random.randint(0, 4, size=shape), chunks=chunks)
try:
f.create_dataset(raw_key, data=np.random.rand(*shape), chunks=chunks)
except ValueError: # Unable to create dataset (name already exists)
pass

try:
f.create_dataset(label_key, data=np.random.randint(0, 4, size=shape), chunks=chunks)
except ValueError: # Unable to create dataset (name already exists)
pass


def create_image_collection_test_data(folder, n_images, min_shape, max_shape):
Expand Down