Skip to content

Commit

Permalink
1D line extraction from 2D images (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet authored Nov 16, 2024
1 parent 093c87e commit bf1588b
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__author__ = "Alister Burt"
__email__ = "[email protected]"

from .project import project_3d_to_2d
from .project import project_3d_to_2d, project_2d_to_1d
from .backproject import backproject_2d_to_3d
from .slice_insertion import insert_central_slices_rfft_3d
from .slice_extraction import extract_central_slices_rfft_3d
12 changes: 10 additions & 2 deletions src/torch_fourier_slice/dft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def rfft_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
return tuple(rfft_shape)


def fftshift_1d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
if rfft is False:
output = torch.fft.fftshift(input, dim=(-1))
else:
output = input
return output


def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
if rfft is False:
output = torch.fft.fftshift(input, dim=(-2, -1))
Expand Down Expand Up @@ -88,9 +96,9 @@ def dft_center(
fft_center = torch.zeros(size=(len(image_shape),), device=device)
image_shape = torch.as_tensor(image_shape).float()
if rfft is True:
image_shape = torch.tensor(rfft_shape(image_shape))
image_shape = torch.tensor(rfft_shape(image_shape), device=device)
if fftshifted is True:
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
fft_center = torch.divide(image_shape, 2, rounding_mode="floor")
if rfft is True:
fft_center[-1] = 0
return fft_center.long()
3 changes: 2 additions & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
from .central_line_fftfreq_grid import central_line_fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
33 changes: 33 additions & 0 deletions src/torch_fourier_slice/grids/central_line_fftfreq_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import einops
import torch

from ..dft_utils import fftshift_1d, rfft_shape


def central_line_fftfreq_grid(
image_shape: tuple[int, int],
rfft: bool,
fftshift: bool = False,
device: torch.device | None = None,
) -> torch.Tensor:
# generate 1d grid of DFT sample frequencies, shape (w, 1)
w, = image_shape[-1:]
grid = (
torch.fft.rfftfreq(w, device=device)
if rfft
else torch.fft.fftfreq(w, device=device)
)

# get grid of same shape with all zeros, append as third coordinate
if rfft is True:
zeros = torch.zeros(size=rfft_shape((w,)), dtype=grid.dtype, device=device)
else:
zeros = torch.zeros(size=(w,), dtype=grid.dtype, device=device)
central_slice_grid, _ = einops.pack([zeros, grid], pattern="w *") # (w, 2)

# fftshift if requested
if fftshift is True:
central_slice_grid = einops.rearrange(central_slice_grid, "w freq -> freq w")
central_slice_grid = fftshift_1d(central_slice_grid, rfft=rfft)
central_slice_grid = einops.rearrange(central_slice_grid, "freq w -> w freq")
return central_slice_grid
66 changes: 65 additions & 1 deletion src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .slice_extraction import extract_central_slices_rfft_3d
from .slice_extraction import extract_central_slices_rfft_2d, extract_central_slices_rfft_3d


def project_3d_to_2d(
Expand Down Expand Up @@ -67,3 +67,67 @@ def project_3d_to_2d(
if pad is True:
projections = projections[..., pad_length:-pad_length, pad_length:-pad_length]
return torch.real(projections)


def project_2d_to_1d(
image: torch.Tensor,
rotation_matrices: torch.Tensor,
pad: bool = True,
fftfreq_max: float | None = None,
) -> torch.Tensor:
"""Project a square image by sampling a central line through its DFT.
Parameters
----------
image: torch.Tensor
`(d, d)` image.
rotation_matrices: torch.Tensor
`(..., 2, 2)` array of rotation matrices for extraction of `lines`.
Rotation matrices left-multiply column vectors containing xy coordinates.
pad: bool
Whether to pad the volume 2x with zeros to increase sampling rate in the DFT.
fftfreq_max: float | None
Maximum frequency (cycles per pixel) included in the projection.
Returns
-------
projections: torch.Tensor
`(..., d)` array of projected lines.
"""
# padding
if pad is True:
pad_length = image.shape[-1] // 2
image = F.pad(image, pad=[pad_length] * 4, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=image.shape,
rfft=False,
fftshift=True,
norm=True,
device=image.device
)
image = image * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(image, dim=(-2, -1)) # image center to array origin
dft = torch.fft.rfftn(dft, dim=(-2, -1))
dft = torch.fft.fftshift(dft, dim=(-2,)) # actual fftshift of 2D rfft

# make projections by taking central slices
projections = extract_central_slices_rfft_2d(
image_rfft=dft,
image_shape=image.shape,
rotation_matrices=rotation_matrices,
fftfreq_max=fftfreq_max
) # (..., w) rfft stack

# transform back to real space
# not needed for 1D: torch.fft.ifftshift(projections, dim=(-2,))
projections = torch.fft.irfftn(projections, dim=(-1))
projections = torch.fft.ifftshift(projections, dim=(-1)) # recenter line in real space

# unpad if required
if pad is True:
projections = projections[..., pad_length:-pad_length]
return torch.real(projections)
1 change: 1 addition & 0 deletions src/torch_fourier_slice/slice_extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._extract_central_slices_rfft_2d import extract_central_slices_rfft_2d
from ._extract_central_slices_rfft_3d import extract_central_slices_rfft_3d
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import einops
import torch
from torch_image_lerp import sample_image_2d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_line_fftfreq_grid import central_line_fftfreq_grid


def extract_central_slices_rfft_2d(
image_rfft: torch.Tensor,
image_shape: tuple[int, int],
rotation_matrices: torch.Tensor, # (..., 2, 2)
fftfreq_max: float | None = None,
) -> torch.Tensor:
"""Extract central slice from an fftshifted rfft."""
# generate grid of DFT sample frequencies for a central slice spanning the x-plane
freq_grid = central_line_fftfreq_grid(
image_shape=image_shape,
rfft=True,
fftshift=True,
device=image_rfft.device,
) # (w, 2) yx coords

# keep track of some shapes
stack_shape = tuple(rotation_matrices.shape[:-2])
rfft_shape = (freq_grid.shape[-2],)
output_shape = (*stack_shape, *rfft_shape)

# get (b, 2, 1) array of yx coordinates to rotate
if fftfreq_max is not None:
freq_grid_mask = freq_grid <= fftfreq_max
valid_coords = freq_grid[freq_grid_mask, ...]
else:
valid_coords = freq_grid
valid_coords = einops.rearrange(valid_coords, "b yx -> b yx 1")

# rotation matrices rotate xyz coordinates, make them rotate zyx coordinates
# xyz:
# [a b c] [x] [ax + by + cz]
# [d e f] [y] = [dx + ey + fz]
# [g h i] [z] [gx + hy + iz]
#
# zyx:
# [i h g] [z] [gx + hy + iz]
# [f e d] [y] = [dx + ey + fz]
# [c b a] [x] [ax + by + cz]
rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1))

# add extra dim to rotation matrices for broadcasting
rotation_matrices = einops.rearrange(rotation_matrices, "... i j -> ... 1 i j")

# rotate all valid coordinates by each rotation matrix
rotated_coords = rotation_matrices @ valid_coords # (..., b, yx, 1)

# remove last dim of size 1
rotated_coords = einops.rearrange(rotated_coords, "... b yx 1 -> ... b yx")

# flip coordinates that ended up in redundant half transform after rotation
conjugate_mask = rotated_coords[..., 1] < 0
rotated_coords[conjugate_mask, ...] *= -1

# convert frequencies to array coordinates in fftshifted DFT
rotated_coords = fftfreq_to_dft_coordinates(
frequencies=rotated_coords, image_shape=image_shape, rfft=True
) # (...) rfft
samples = sample_image_2d(image=image_rfft, coordinates=rotated_coords)

# take complex conjugate of values from redundant half transform
samples[conjugate_mask] = torch.conj(samples[conjugate_mask])

# insert samples back into DFTs
projection_image_dfts = torch.zeros(
output_shape, device=image_rfft.device, dtype=image_rfft.dtype
)
if fftfreq_max is None:
freq_grid_mask = torch.ones(
size=rfft_shape, dtype=torch.bool, device=image_rfft.device
)

projection_image_dfts[..., freq_grid_mask] = samples

return projection_image_dfts
20 changes: 19 additions & 1 deletion tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_slice import project_3d_to_2d, project_2d_to_1d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group

Expand All @@ -25,6 +25,24 @@ def test_project_3d_to_2d_rotation_center():
assert (i, j) == (16, 16)


def test_project_2d_to_1d_rotation_center():
# rotation center should be at position of DC in DFT
image = torch.zeros((32, 32))
image[16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=2, size=100)).float()
projections = project_2d_to_1d(
image=image,
rotation_matrices=rotation_matrices,
)

# check max is always at (16), implying point (16) never moves
for image in projections:
i = torch.argmax(image)
assert i == 16


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
Expand Down

0 comments on commit bf1588b

Please sign in to comment.