Skip to content

Commit

Permalink
use cryotypes for projection model backend (#18)
Browse files Browse the repository at this point in the history
* feat: initialize cryotypes projection model in usage example

* feat: rely everywhere on cryotypes ProjectionModel

* refactor: add utility functions for matrix generation to facilitate projection_model to matrix conversion

* refactor: remoeve image stretch function

* fix: set cryotypes version to 0.2
  • Loading branch information
McHaillet authored Nov 14, 2024
1 parent dda6553 commit ecaa1bc
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 215 deletions.
37 changes: 21 additions & 16 deletions examples/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import pooch
import torch
from cryotypes.projectionmodel import ProjectionModel
from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL
from torch_fourier_rescale import fourier_rescale_2d
from torch_subpixel_crop import subpixel_crop_2d

Expand All @@ -26,10 +28,10 @@

IMAGE_FILE = Path(GOODBOY.fetch("tomo200528_107.st", progressbar=True))
with open(Path(GOODBOY.fetch("tomo200528_107.rawtlt"))) as f:
STAGE_TILT_ANGLE_PRIORS = torch.tensor([float(x) for x in f.readlines()])
STAGE_TILT_ANGLE_PRIORS = [float(x) for x in f.readlines()]
IMAGE_PIXEL_SIZE = 1.724
# this angle is assumed to be a clockwise forward rotation after projecting the sample
TILT_AXIS_ANGLE_PRIOR = torch.tensor(-88.7)
TILT_AXIS_ANGLE_PRIOR = -88.7
ALIGNMENT_PIXEL_SIZE = IMAGE_PIXEL_SIZE * 8
ALIGN_Z = int(1600 / ALIGNMENT_PIXEL_SIZE) # number is in A
RECON_Z = int(2400 / ALIGNMENT_PIXEL_SIZE)
Expand All @@ -41,6 +43,19 @@
# Set the device for running
DEVICE = "cuda:0"

# Initialize the projection-model prior
projection_model_prior = ProjectionModel(
{
PMDL.ROTATION_Z: TILT_AXIS_ANGLE_PRIOR,
PMDL.ROTATION_Y: STAGE_TILT_ANGLE_PRIORS,
PMDL.ROTATION_X: 0.0,
PMDL.SHIFT_X: 0.0,
PMDL.SHIFT_Y: 0.0,
PMDL.EXPERIMENT_ID: IMAGE_FILE.stem,
PMDL.PIXEL_SPACING: ALIGNMENT_PIXEL_SIZE,
PMDL.SOURCE: IMAGE_FILE.name,
}
)

tilt_series = torch.as_tensor(mrcfile.read(IMAGE_FILE))

Expand Down Expand Up @@ -69,25 +84,21 @@
size = min(h, w)

# Move all the input to the device
tilt_angles, tilt_axis_angles, shifts = tilt_series_alignment(
projection_model_optimized = tilt_series_alignment(
tilt_series.to(DEVICE),
STAGE_TILT_ANGLE_PRIORS,
TILT_AXIS_ANGLE_PRIOR,
projection_model_prior,
ALIGN_Z,
find_tilt_angle_offset=False,
)

final, aligned_ts = filtered_back_projection_3d(
final = filtered_back_projection_3d(
tilt_series,
(RECON_Z, size, size),
tilt_angles,
tilt_axis_angles,
shifts,
projection_model_optimized,
weighting=WEIGHTING,
object_diameter=OBJECT_DIAMETER,
)
final = final.to("cpu")
aligned_ts = aligned_ts.to("cpu")

OUTPUT_DIR.mkdir(exist_ok=True)
mrcfile.write(
Expand All @@ -96,9 +107,3 @@
voxel_size=ALIGNMENT_PIXEL_SIZE,
overwrite=True,
)
mrcfile.write(
OUTPUT_DIR.joinpath(IMAGE_FILE.with_suffix(".ali").name),
aligned_ts.detach().numpy().astype(np.float32),
voxel_size=ALIGNMENT_PIXEL_SIZE,
overwrite=True,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"torch-cubic-spline-grids",
"torch-fourier-shift",
"torch-image-lerp",
"cryotypes == 0.2",
"einops",
"numpy",
"scipy",
Expand Down
3 changes: 1 addition & 2 deletions src/tttsa/affine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""2D and 3D affine transform functionality."""

from .affine_transform import affine_transform_2d, affine_transform_3d, stretch_image
from .affine_transform import affine_transform_2d, affine_transform_3d

__all__ = [
"affine_transform_2d",
"affine_transform_3d",
"stretch_image",
]
26 changes: 1 addition & 25 deletions src/tttsa/affine/affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,7 @@
import torch.nn.functional as F
from torch_grid_utils import coordinate_grid

from tttsa.transformations import R_2d, T_2d
from tttsa.utils import array_to_grid_sample, dft_center, homogenise_coordinates


def stretch_image(
image: torch.Tensor,
stretch: torch.Tensor | float,
tilt_axis_angle: torch.Tensor | float,
) -> torch.Tensor:
"""Utility function for stretching an image on the tilt axis."""
image_center = dft_center(image.shape, rfft=False, fftshifted=True)
# construct matrix
s0 = T_2d(-image_center)
r_forward = R_2d(tilt_axis_angle, yx=True)
r_backward = torch.linalg.inv(r_forward)
m_stretch = torch.eye(3)
m_stretch[1, 1] = stretch # this is a shear matrix
s1 = T_2d(image_center)
m_affine = s1 @ r_forward @ m_stretch @ r_backward @ s0
# transform image
stretched = affine_transform_2d(
image,
m_affine,
)
return stretched
from tttsa.utils import array_to_grid_sample, homogenise_coordinates


def affine_transform_2d(
Expand Down
54 changes: 28 additions & 26 deletions src/tttsa/back_projection/filtered_back_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@
import einops
import torch
import torch.nn.functional as F
from cryotypes.projectionmodel import ProjectionModel
from cryotypes.projectionmodel import ProjectionModelDataLabels as PMDL
from torch_grid_utils import coordinate_grid

from tttsa.affine import affine_transform_2d
from tttsa.transformations import R_2d, Ry, T, T_2d
from tttsa.utils import array_to_grid_sample, dft_center, homogenise_coordinates
from tttsa.transformations import (
projection_model_to_backproject_matrix,
projection_model_to_tsa_matrix,
)
from tttsa.utils import array_to_grid_sample, homogenise_coordinates

# update shift
PMDL.SHIFT = [PMDL.SHIFT_Y, PMDL.SHIFT_X]


def filtered_back_projection_3d(
tilt_series: torch.Tensor,
tomogram_dimensions: Tuple[int, int, int],
tilt_angles: torch.Tensor,
tilt_axis_angles: torch.Tensor,
shifts: torch.Tensor,
projection_model: ProjectionModel,
weighting: str = "exact",
object_diameter: float | None = None,
) -> torch.Tensor:
Expand All @@ -43,19 +49,12 @@ def filtered_back_projection_3d(
n_tilts, h, w = tilt_series.shape # for simplicity assume square images
tilt_image_dimensions = (h, w)
transformed_image_dimensions = tomogram_dimensions[-2:]
tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
tilt_image_center = dft_center(tilt_image_dimensions, rfft=False, fftshifted=True)
transformed_image_center = dft_center(
transformed_image_dimensions, rfft=False, fftshifted=True
)
_, filter_size = transformed_image_dimensions

# generate the 2d alignment affine matrix
s0 = T_2d(-tilt_image_center)
r0 = R_2d(tilt_axis_angles, yx=True)
s1 = T_2d(-shifts)
s2 = T_2d(transformed_image_center)
M = torch.linalg.inv(s2 @ s1 @ r0 @ s0).to(device)
M = projection_model_to_tsa_matrix(
projection_model, tilt_image_dimensions, transformed_image_dimensions
).to(device)

aligned_ts = affine_transform_2d(
tilt_series,
Expand All @@ -69,7 +68,7 @@ def filtered_back_projection_3d(
raise ValueError(
"Calculation of exact weighting requires an object " "diameter."
)
if len(tilt_angles) == 1:
if n_tilts == 1:
# set explicitly as tensor to ensure correct typing
filters = torch.tensor(1.0, device=device)
else: # slice_width could be provided as a function argument it can be
Expand All @@ -83,6 +82,7 @@ def filtered_back_projection_3d(
/ filter_size,
"q -> 1 1 q",
)
tilt_angles = torch.as_tensor(projection_model[PMDL.ROTATION_Y])
sampling = torch.sin(
torch.deg2rad(
torch.abs(einops.rearrange(tilt_angles, "n -> n 1") - tilt_angles)
Expand Down Expand Up @@ -124,15 +124,17 @@ def filtered_back_projection_3d(
if len(weighted.shape) == 2: # rfftn gets rid of batch dimension: add it back
weighted = einops.rearrange(weighted, "h w -> 1 h w")

# create recon from weighted-aligned ts
s0 = T(-tomogram_center)
r0 = Ry(tilt_angles, zyx=True)
s1 = T(tomogram_center)
# This would actually be a double linalg.inv. First for the inverse of the
# forward projection alignment model. The second for the affine transform.
# It could be more logical to use affine_transform_3d, but it requires
# recalculation of the grid for every iteration.
M = einops.rearrange(s1 @ r0 @ s0, "... i j -> ... 1 1 i j").to(device)
# We need to lingalg.inv the matrix as the affine transform is done inside
# this function. It could be more logical to use affine_transform_3d (and do
# inversion inside) but it requires recalculation of the grid for every iteration.
M = einops.rearrange(
torch.linalg.inv(
projection_model_to_backproject_matrix(
projection_model, tomogram_dimensions
)
),
"... i j -> ... 1 1 i j",
).to(device)

reconstruction = torch.zeros(
tomogram_dimensions, dtype=torch.float32, device=device
Expand All @@ -156,4 +158,4 @@ def filtered_back_projection_3d(
mode="bilinear",
)
)
return reconstruction, aligned_ts
return reconstruction
57 changes: 29 additions & 28 deletions src/tttsa/coarse_align.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Coarse tilt-series alignment functions, also with stretching."""

import einops
import torch

from .affine import stretch_image
from .affine import affine_transform_2d
from .alignment import find_image_shift
from .transformations import stretch_matrix


def coarse_align(
Expand All @@ -12,20 +14,25 @@ def coarse_align(
mask: torch.Tensor,
) -> torch.Tensor:
"""Find coarse shifts of images without stretching along tilt axis."""
shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32)
n_tilts = len(tilt_series)
shifts = torch.zeros((n_tilts, 2), dtype=torch.float32)
ts_masked = tilt_series * mask
ts_masked -= einops.reduce(ts_masked, "tilt h w -> tilt 1 1", reduction="mean")
ts_masked /= torch.std(ts_masked, dim=(-2, -1), keepdim=True)

# find coarse alignment for negative tilts
current_shift = torch.zeros(2)
for i in range(reference_tilt_id, 0, -1):
shift = find_image_shift(tilt_series[i] * mask, tilt_series[i - 1] * mask)
shift = find_image_shift(ts_masked[i], ts_masked[i - 1])
current_shift += shift
shifts[i - 1] = current_shift

# find coarse alignment positive tilts
current_shift = torch.zeros(2)
for i in range(reference_tilt_id, tilt_series.shape[0] - 1, 1):
for i in range(reference_tilt_id, n_tilts - 1, 1):
shift = find_image_shift(
tilt_series[i] * mask,
tilt_series[i + 1] * mask,
ts_masked[i],
ts_masked[i + 1],
)
current_shift += shift
shifts[i + 1] = current_shift
Expand All @@ -40,21 +47,20 @@ def stretch_align(
tilt_axis_angles: torch.Tensor,
) -> torch.Tensor:
"""Find coarse shifts of images while stretching each pair along the tilt axis."""
shifts = torch.zeros((len(tilt_series), 2), dtype=torch.float32)
n_tilts, h, w = tilt_series.shape
tilt_image_dimensions = (h, w)
shifts = torch.zeros((n_tilts, 2), dtype=torch.float32)
cos_ta = torch.cos(torch.deg2rad(tilt_angles))

# find coarse alignment for negative tilts
current_shift = torch.zeros(2)
for i in range(reference_tilt_id, 0, -1):
scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos(
torch.deg2rad(tilt_angles[i - 1 : i])
)
stretched = (
stretch_image(
tilt_series[i - 1],
scale_factor,
tilt_axis_angles[i - 1],
)
* mask
M = stretch_matrix(
tilt_image_dimensions,
tilt_axis_angles[i - 1],
scale_factor=cos_ta[i : i + 1] / cos_ta[i - 1 : i],
)
stretched = affine_transform_2d(tilt_series[i - 1], M) * mask
stretched = (stretched - stretched.mean()) / stretched.std()
raw = tilt_series[i] * mask
raw = (raw - raw.mean()) / raw.std()
Expand All @@ -63,18 +69,13 @@ def stretch_align(
shifts[i - 1] = current_shift
# find coarse alignment positive tilts
current_shift = torch.zeros(2)
for i in range(reference_tilt_id, tilt_series.shape[0] - 1, 1):
scale_factor = torch.cos(torch.deg2rad(tilt_angles[i : i + 1])) / torch.cos(
torch.deg2rad(tilt_angles[i + 1 : i + 2])
)
stretched = (
stretch_image(
tilt_series[i + 1],
scale_factor,
tilt_axis_angles[i + 1],
)
* mask
for i in range(reference_tilt_id, n_tilts - 1, 1):
M = stretch_matrix(
tilt_image_dimensions,
tilt_axis_angles[i + 1],
scale_factor=cos_ta[i : i + 1] / cos_ta[i + 1 : i + 2],
)
stretched = affine_transform_2d(tilt_series[i + 1], M) * mask
stretched = (stretched - stretched.mean()) / stretched.std()
raw = tilt_series[i] * mask
raw = (raw - raw.mean()) / raw.std()
Expand Down
Loading

0 comments on commit ecaa1bc

Please sign in to comment.