-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8dc7928
commit 1253df6
Showing
4 changed files
with
94 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import numpy as np | ||
import torch | ||
|
||
from torch_subpixel_crop import subpixel_crop_3d | ||
from skimage import data | ||
|
||
image = torch.tensor(data.binary_blobs(length=128, n_dim=3)).float() | ||
positions = torch.tensor(np.random.uniform(low=0, high=127, size=(100, 3))).float() | ||
|
||
crops = subpixel_crop_3d( | ||
image=image, | ||
positions=positions, | ||
sidelength=32 | ||
) | ||
|
||
# (100, 32, 32, 32) | ||
print(crops.shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,4 @@ | |
__email__ = "[email protected]" | ||
|
||
from .subpixel_crop_2d import subpixel_crop_2d | ||
from .subpixel_crop_3d import subpixel_crop_3d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import torch | ||
import einops | ||
import torch.nn.functional as F | ||
from skimage import data | ||
from torch_fourier_shift import fourier_shift_image_3d | ||
from torch_grid_utils import coordinate_grid | ||
|
||
from torch_subpixel_crop.dft_utils import dft_center | ||
from torch_subpixel_crop.grid_sample_utils import array_to_grid_sample | ||
|
||
|
||
def subpixel_crop_3d( | ||
image: torch.Tensor, # (d, h, w) | ||
positions: torch.Tensor, # (b, 3) zyx | ||
sidelength: int, | ||
) -> torch.Tensor: | ||
"""Extract cubic patches from a 3D image with subpixel precision. | ||
Patches are extracted at the nearest integer coordinates then phase shifted | ||
such that the requested position is at the center of the patch. | ||
The center of an image is defined to be the position of the DC component of an | ||
fftshifted discrete Fourier transform. | ||
Parameters | ||
---------- | ||
image: torch.Tensor | ||
`(d, h, w)` array containing the volume. | ||
positions: torch.Tensor | ||
`(b, 3)` array of coordinates for patch centers. | ||
sidelength: int | ||
Sidelength of cubic patches extracted from `image`. | ||
Returns | ||
------- | ||
patches: torch.Tensor | ||
`(b, sidelength, sidelength, sidelength)` array of cropped regions from `volume` | ||
with their centers at `positions`. | ||
""" | ||
d, h, w = image.shape | ||
b, _ = positions.shape | ||
|
||
# find integer positions and shifts to be applied | ||
integer_positions = torch.round(positions) | ||
shifts = integer_positions - positions | ||
|
||
# generate coordinate grids for sampling around each integer position | ||
pd, ph, pw = (sidelength, sidelength, sidelength) | ||
center = dft_center((pd, ph, pw), rfft=False, fftshifted=True, device=image.device) | ||
grid = coordinate_grid( | ||
image_shape=(pd, ph, pw), | ||
center=center, | ||
device=image.device | ||
) # (d, h, w, 2) | ||
broadcastable_positions = einops.rearrange(integer_positions, 'b zyx -> b 1 1 1 zyx') | ||
grid = grid + broadcastable_positions # (b, d, h, w, 3) | ||
|
||
# extract patches, grid sample handles boundaries | ||
patches = F.grid_sample( | ||
input=einops.repeat(image, 'd h w -> b 1 d h w', b=b), | ||
grid=array_to_grid_sample(grid, array_shape=(d, h, w)), | ||
mode='nearest', | ||
padding_mode='zeros', | ||
align_corners=True | ||
) | ||
patches = einops.rearrange(patches, 'b 1 d h w -> b d h w') | ||
|
||
# phase shift to center images | ||
patches = fourier_shift_image_3d(image=patches, shifts=shifts) | ||
return patches |