Skip to content

Commit

Permalink
torch-fourier-slice for common line extraction (#19)
Browse files Browse the repository at this point in the history
* feat: add fourier slice extraction for common lines

* fix: stacking and add comment to explain logic

* update: fourier slice to teamtomo main

* refactor: remove old common lines code

* fix to new torch-fourier-slice versions
  • Loading branch information
McHaillet authored Nov 19, 2024
1 parent f278344 commit 68a5ccb
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 47 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"torch-cubic-spline-grids",
"torch-fourier-shift",
"torch-image-lerp",
"torch-fourier-slice >= 0.0.5",
"cryotypes == 0.2",
"einops",
"numpy",
Expand Down
27 changes: 19 additions & 8 deletions src/tttsa/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import einops
import torch
from torch_cubic_spline_grids import CubicBSplineGrid1d
from torch_fourier_slice import project_2d_to_1d

from .affine import affine_transform_2d
from .projection import common_lines_projection
from .transformations import T_2d, stretch_matrix
from .transformations import R_2d, T_2d, stretch_matrix


def stretch_loss(
Expand Down Expand Up @@ -69,10 +69,10 @@ def optimize_tilt_axis_angle(
coarse_aligned_masked = aligned_ts * coarse_alignment_mask

# generate a weighting for the common line ROI by projecting the mask
mask_weights = common_lines_projection(
einops.rearrange(coarse_alignment_mask, "h w -> 1 h w"),
0.0, # angle does not matter
)
mask_weights = project_2d_to_1d(
coarse_alignment_mask,
torch.eye(2).to(coarse_alignment_mask.device), # angle does not matter
).squeeze() # remove starting empty dimension
mask_weights /= mask_weights.max() # normalise to 0 and 1

# optimize tilt axis angle
Expand All @@ -94,8 +94,19 @@ def optimize_tilt_axis_angle(
)

def closure() -> torch.Tensor:
tilt_axis_angles = tilt_axis_grid(interpolation_points)
projections = common_lines_projection(coarse_aligned_masked, tilt_axis_angles)
# The common line is the projection perpendicular to the aligned tilt-axis (
# aligned with the y-axis), hence add 90 degrees to project along the x-axis.
M = R_2d(tilt_axis_grid(interpolation_points) + 90, yx=False)[:, :2, :2]
projections = einops.rearrange(
[
project_2d_to_1d(
coarse_aligned_masked[(i,)],
M[(i,)].to(coarse_aligned_masked.device),
).squeeze() # squeeze as we only calculate one projection
for i in range(len(coarse_aligned_masked))
],
"n w -> n w",
)
projections = projections - einops.reduce(
projections, "tilt w -> tilt 1", reduction="mean"
)
Expand Down
3 changes: 1 addition & 2 deletions src/tttsa/projection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Projection of images and volumes."""

from .project_real import common_lines_projection, tomogram_reprojection
from .project_real import tomogram_reprojection

__all__ = [
"common_lines_projection",
"tomogram_reprojection",
]
38 changes: 2 additions & 36 deletions src/tttsa/projection/project_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,50 +9,15 @@
from torch_grid_utils import coordinate_grid
from torch_image_lerp import insert_into_image_2d

from tttsa.affine import affine_transform_2d
from tttsa.transformations import (
R_2d,
T_2d,
projection_model_to_projection_matrix,
)
from tttsa.utils import dft_center, homogenise_coordinates
from tttsa.utils import homogenise_coordinates

# update shift
PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X]


def common_lines_projection(
images: torch.Tensor,
tilt_axis_angles: torch.Tensor,
# this might as well takes shifts
) -> torch.Tensor:
"""Predict a projection from an intermediate reconstruction.
For now only assumes to project with a single matrix, but should also work for
sets of matrices.
"""
device = images.device
image_dimensions = images.shape[-2:]

# TODO pad image if not square

image_center = dft_center(image_dimensions, rfft=False, fftshifted=True)

# time for real space projection
s0 = T_2d(-image_center)
r0 = R_2d(tilt_axis_angles, yx=True)
s1 = T_2d(image_center)
# invert because the tilt axis angle is forward in the sample projection model
M = torch.linalg.inv(s1 @ r0 @ s0).to(device)

rotated = affine_transform_2d(
images,
M,
)
projections = rotated.mean(axis=-1).squeeze()
return projections


def tomogram_reprojection(
tomogram: torch.Tensor,
tilt_image_dimensions: Tuple[int, int],
Expand All @@ -77,6 +42,7 @@ def tomogram_reprojection(
grid = einops.rearrange(grid, "d h w coords -> d h w coords 1")
grid = Mproj @ grid
grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")

projection, weights = insert_into_image_2d(
tomogram.view(-1), # flatten
grid.view(-1, 2),
Expand Down
2 changes: 1 addition & 1 deletion src/tttsa/tttsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def tilt_series_alignment(
).numpy()

start_taa_grid_points = 1 # taa = tilt-axis angle
pm_taa_grid_points = 3 # pm = projection matching
pm_taa_grid_points = 1 # pm = projection matching

console.print(
f"=== Optimizing tilt-axis angle with {start_taa_grid_points} grid point."
Expand Down

0 comments on commit 68a5ccb

Please sign in to comment.