-
Notifications
You must be signed in to change notification settings - Fork 4
/
projection_backprojection_cycle.py
47 lines (37 loc) · 1.21 KB
/
projection_backprojection_cycle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""A simple projection/backprojection cycle implementation."""
import os

import mrcfile
import torch
from scipy.stats import special_ortho_group
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
# number of random projection directions to simulate
N_IMAGES = 1000
# input volume; set the VOLUME_PATH environment variable to use another file
# instead of editing this user-specific default
VOLUME_PATH = os.environ.get(
    'VOLUME_PATH', '/Users/burta2/data/4v6x_bin4.mrc'
)
torch.manual_seed(42)

# load a volume and normalise to zero mean / unit standard deviation.
# Cast to native-endian float32 first: integer-mode MRC data would make
# torch.mean raise, and torch does not accept arrays whose byte order differs
# from the native one (possible for big-endian files).
volume = torch.tensor(mrcfile.read(VOLUME_PATH).astype('float32'))
volume -= torch.mean(volume)
volume /= torch.std(volume)

# random rotation matrices for projection (operate on xyz column vectors).
# scipy returns a bare (3, 3) array when size == 1, so reshape to always
# obtain a (N_IMAGES, 3, 3) batch.
rotations = torch.tensor(
    special_ortho_group.rvs(dim=3, size=N_IMAGES, random_state=42)
).float().reshape(-1, 3, 3)
# Projection and backprojection use the same geometry and padding.
slice_options = {'rotation_matrices': rotations, 'pad': True}

# forward model: one 2D projection per rotation, stacked as (b, h, w)
projections = project_3d_to_2d(volume, **slice_options)

# reconstruct a 3D volume by backprojecting every projection image
reconstruction = backproject_2d_to_3d(images=projections, **slice_options)

# bring the reconstruction to zero mean / unit standard deviation, the same
# scale as the normalised ground-truth volume
reconstruction = reconstruction - torch.mean(reconstruction)
reconstruction = reconstruction / torch.std(reconstruction)
# visualise: one viewer for the projection stack, and a second viewer
# comparing the ground-truth volume with its reconstruction.
import napari

# Each window keeps its own name so that both viewers stay referenced.
# `viewer` is kept as the name of the 3D viewer for interactive sessions.
projection_viewer = napari.Viewer()
projection_viewer.add_image(projections.numpy(), name='projections')

viewer = napari.Viewer(ndisplay=3)  # open this viewer in 3D display mode
viewer.add_image(volume.numpy(), name='ground truth')
viewer.add_image(reconstruction.numpy(), name='reconstruction')

# start napari's event loop so the windows stay open
napari.run()