Skip to content

Commit

Permalink
add 3d crop
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt committed Jun 16, 2024
1 parent 8dc7928 commit 1253df6
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
17 changes: 17 additions & 0 deletions examples/extract_from_single_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np
import torch

from torch_subpixel_crop import subpixel_crop_3d
from skimage import data

image = torch.tensor(data.binary_blobs(length=128, n_dim=3)).float()
positions = torch.tensor(np.random.uniform(low=0, high=127, size=(100, 3))).float()

crops = subpixel_crop_3d(
image=image,
positions=positions,
sidelength=32
)

# (100, 32, 32, 32)
print(crops.shape)
1 change: 1 addition & 0 deletions src/torch_subpixel_crop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
__email__ = "[email protected]"

from .subpixel_crop_2d import subpixel_crop_2d
from .subpixel_crop_3d import subpixel_crop_3d
7 changes: 5 additions & 2 deletions src/torch_subpixel_crop/subpixel_crop_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
def subpixel_crop_2d(
image: torch.Tensor, positions: torch.Tensor, sidelength: int,
):
"""Extract square patches from 2D images at positions with subpixel precision.
"""Extract square patches from 2D images with subpixel precision.
Patches are extracted at the nearest integer coordinates then phase shifted
such that the requested position is at the center of the patch.
The center of an image is defined to be the position of the DC component of an
fftshifted discrete Fourier transform.
Parameters
----------
image: torch.Tensor
Expand All @@ -29,7 +32,7 @@ def subpixel_crop_2d(
-------
patches: torch.Tensor
`(..., b, sidelength, sidelength)` or `(..., sidelength, sidelength)` array
of patches from `images` with their centers at `positions`.
of patches from `image` with their centers at `positions`.
"""
# handling batched input
if image.ndim == 2:
Expand Down
71 changes: 71 additions & 0 deletions src/torch_subpixel_crop/subpixel_crop_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import torch
import einops
import torch.nn.functional as F
from skimage import data
from torch_fourier_shift import fourier_shift_image_3d
from torch_grid_utils import coordinate_grid

from torch_subpixel_crop.dft_utils import dft_center
from torch_subpixel_crop.grid_sample_utils import array_to_grid_sample


def subpixel_crop_3d(
image: torch.Tensor, # (d, h, w)
positions: torch.Tensor, # (b, 3) zyx
sidelength: int,
) -> torch.Tensor:
"""Extract cubic patches from a 3D image with subpixel precision.
Patches are extracted at the nearest integer coordinates then phase shifted
such that the requested position is at the center of the patch.
The center of an image is defined to be the position of the DC component of an
fftshifted discrete Fourier transform.
Parameters
----------
image: torch.Tensor
`(d, h, w)` array containing the volume.
positions: torch.Tensor
`(b, 3)` array of coordinates for patch centers.
sidelength: int
Sidelength of cubic patches extracted from `image`.
Returns
-------
patches: torch.Tensor
`(b, sidelength, sidelength, sidelength)` array of cropped regions from `volume`
with their centers at `positions`.
"""
d, h, w = image.shape
b, _ = positions.shape

# find integer positions and shifts to be applied
integer_positions = torch.round(positions)
shifts = integer_positions - positions

# generate coordinate grids for sampling around each integer position
pd, ph, pw = (sidelength, sidelength, sidelength)
center = dft_center((pd, ph, pw), rfft=False, fftshifted=True, device=image.device)
grid = coordinate_grid(
image_shape=(pd, ph, pw),
center=center,
device=image.device
) # (d, h, w, 2)
broadcastable_positions = einops.rearrange(integer_positions, 'b zyx -> b 1 1 1 zyx')
grid = grid + broadcastable_positions # (b, d, h, w, 3)

# extract patches, grid sample handles boundaries
patches = F.grid_sample(
input=einops.repeat(image, 'd h w -> b 1 d h w', b=b),
grid=array_to_grid_sample(grid, array_shape=(d, h, w)),
mode='nearest',
padding_mode='zeros',
align_corners=True
)
patches = einops.rearrange(patches, 'b 1 d h w -> b d h w')

# phase shift to center images
patches = fourier_shift_image_3d(image=patches, shifts=shifts)
return patches

0 comments on commit 1253df6

Please sign in to comment.