diff --git a/pyproject.toml b/pyproject.toml index 39cb1c0..4ea8135 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "torch-cubic-spline-grids", "torch-fourier-shift", "torch-image-lerp", + "torch-fourier-slice >= 0.0.5", "cryotypes == 0.2", "einops", "numpy", diff --git a/src/tttsa/optimizers.py b/src/tttsa/optimizers.py index f76d96f..6045c8d 100644 --- a/src/tttsa/optimizers.py +++ b/src/tttsa/optimizers.py @@ -3,10 +3,10 @@ import einops import torch from torch_cubic_spline_grids import CubicBSplineGrid1d +from torch_fourier_slice import project_2d_to_1d from .affine import affine_transform_2d -from .projection import common_lines_projection -from .transformations import T_2d, stretch_matrix +from .transformations import R_2d, T_2d, stretch_matrix def stretch_loss( @@ -69,10 +69,10 @@ def optimize_tilt_axis_angle( coarse_aligned_masked = aligned_ts * coarse_alignment_mask # generate a weighting for the common line ROI by projecting the mask - mask_weights = common_lines_projection( - einops.rearrange(coarse_alignment_mask, "h w -> 1 h w"), - 0.0, # angle does not matter - ) + mask_weights = project_2d_to_1d( + coarse_alignment_mask, + torch.eye(2).to(coarse_alignment_mask.device), # angle does not matter + ).squeeze() # remove starting empty dimension mask_weights /= mask_weights.max() # normalise to 0 and 1 # optimize tilt axis angle @@ -94,8 +94,19 @@ def optimize_tilt_axis_angle( ) def closure() -> torch.Tensor: - tilt_axis_angles = tilt_axis_grid(interpolation_points) - projections = common_lines_projection(coarse_aligned_masked, tilt_axis_angles) + # The common line is the projection perpendicular to the aligned tilt-axis ( + # aligned with the y-axis), hence add 90 degrees to project along the x-axis. + M = R_2d(tilt_axis_grid(interpolation_points) + 90, yx=False)[:, :2, :2] + projections = einops.rearrange( + [ + project_2d_to_1d( + coarse_aligned_masked[(i,)], + M[(i,)].to(coarse_aligned_masked.device), + ).squeeze() # squeeze as we only calculate one projection + for i in range(len(coarse_aligned_masked)) + ], + "n w -> n w", + ) projections = projections - einops.reduce( projections, "tilt w -> tilt 1", reduction="mean" ) diff --git a/src/tttsa/projection/__init__.py b/src/tttsa/projection/__init__.py index d054365..e7aef33 100644 --- a/src/tttsa/projection/__init__.py +++ b/src/tttsa/projection/__init__.py @@ -1,8 +1,7 @@ """Projection of images and volumes.""" -from .project_real import common_lines_projection, tomogram_reprojection +from .project_real import tomogram_reprojection __all__ = [ - "common_lines_projection", "tomogram_reprojection", ] diff --git a/src/tttsa/projection/project_real.py b/src/tttsa/projection/project_real.py index b1608c0..349336b 100644 --- a/src/tttsa/projection/project_real.py +++ b/src/tttsa/projection/project_real.py @@ -9,50 +9,15 @@ from torch_grid_utils import coordinate_grid from torch_image_lerp import insert_into_image_2d -from tttsa.affine import affine_transform_2d from tttsa.transformations import ( - R_2d, - T_2d, projection_model_to_projection_matrix, ) -from tttsa.utils import dft_center, homogenise_coordinates +from tttsa.utils import homogenise_coordinates # update shift PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X] -def common_lines_projection( - images: torch.Tensor, - tilt_axis_angles: torch.Tensor, - # this might as well takes shifts -) -> torch.Tensor: - """Predict a projection from an intermediate reconstruction. - - For now only assumes to project with a single matrix, but should also work for - sets of matrices. - """ - device = images.device - image_dimensions = images.shape[-2:] - - # TODO pad image if not square - - image_center = dft_center(image_dimensions, rfft=False, fftshifted=True) - - # time for real space projection - s0 = T_2d(-image_center) - r0 = R_2d(tilt_axis_angles, yx=True) - s1 = T_2d(image_center) - # invert because the tilt axis angle is forward in the sample projection model - M = torch.linalg.inv(s1 @ r0 @ s0).to(device) - - rotated = affine_transform_2d( - images, - M, - ) - projections = rotated.mean(axis=-1).squeeze() - return projections - - def tomogram_reprojection( tomogram: torch.Tensor, tilt_image_dimensions: Tuple[int, int], @@ -77,6 +42,7 @@ def tomogram_reprojection( grid = einops.rearrange(grid, "d h w coords -> d h w coords 1") grid = Mproj @ grid grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords") + projection, weights = insert_into_image_2d( tomogram.view(-1), # flatten grid.view(-1, 2), diff --git a/src/tttsa/tttsa.py b/src/tttsa/tttsa.py index 00cd94f..4526683 100644 --- a/src/tttsa/tttsa.py +++ b/src/tttsa/tttsa.py @@ -64,7 +64,7 @@ def tilt_series_alignment( ).numpy() start_taa_grid_points = 1 # taa = tilt-axis angle - pm_taa_grid_points = 3 # pm = projection matching + pm_taa_grid_points = 1 # pm = projection matching console.print( f"=== Optimizing tilt-axis angle with {start_taa_grid_points} grid point."