Skip to content

Commit

Permalink
Build example
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Jun 7, 2024
1 parent 6ad9fd8 commit 997ebf3
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 24 deletions.
2 changes: 1 addition & 1 deletion data
1 change: 1 addition & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Computed Tomography
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison
examples/ct_projector_comparison_3d
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Expand Down
2 changes: 2 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Computed Tomography
CT Training and Reconstructions with UNet
`ct_projector_comparison.py <ct_projector_comparison.py>`_
X-ray Transform Comparison
`ct_projector_comparison_3d.py <ct_projector_comparison_3d.py>`_
X-ray Transform Comparison in 3D
`ct_multi_cs_tv_admm.py <ct_multi_cs_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
`ct_multi_tv_admm.py <ct_multi_tv_admm.py>`_
Expand Down
96 changes: 77 additions & 19 deletions examples/scripts/ct_projector_comparison_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
X-ray Transform Comparison in 3D
================================
This example compares SCICO's native 3D X-ray transform algorithm
to that of the ASTRA toolbox.
This example shows how to define a SCICO native 3D X-ray transform using
ASTRA toolbox conventions and vice versa.
"""

import numpy as np
Expand All @@ -35,13 +35,12 @@
x = create_block_phantom(in_shape)
x = jnp.array(x)

diagonal_length = int(jnp.ceil(jnp.sqrt(3) * N))
# use rectangular detector to check whether it is handled correctly
out_shape = (diagonal_length, diagonal_length + 1)
out_shape = (N, N + 1)


"""
Set up SCICO projection
Set up SCICO projection.
"""
num_angles = 7

Expand All @@ -58,7 +57,7 @@


"""
Specify geometry using SCICO conventions and project
Specify geometry using SCICO conventions and project.
"""
num_repeats = 3

Expand All @@ -83,39 +82,61 @@
timer_scico.stop("avg_fwd")
timer_scico.td["avg_fwd"] /= num_repeats

timer_scico.start("first_back")
HTy_scico = H_scico.T @ y_scico
timer_scico.stop("first_back")

timer_scico.start("avg_back")
for _ in range(num_repeats):
HTy_scico = H_scico.T @ y_scico
jax.block_until_ready(HTy_scico)
timer_scico.stop("avg_back")
timer_scico.td["avg_back"] /= num_repeats


"""
Convert SCICO geometry to ASTRA and project
Convert SCICO geometry to ASTRA and project.
"""

P_to_astra_vectors = scico.linop.xray.P_to_vectors(in_shape, P, out_shape)

timer_astra = Timer()
timer_astra.start("astra_init")
timer_astra.start("init")
H_astra_from_scico = astra.XRayTransform3D(
input_shape=in_shape, det_count=out_shape, vectors=P_to_astra_vectors
)
timer_astra.stop("astra_init")
timer_astra.stop("init")

timer_astra.start("first_fwd")
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_scico)
timer_astra.stop("first_fwd")

timer_astra.start("first_fwd")
y_astra_from_scico = H_scico @ x
y_astra_from_scico = H_astra_from_scico @ x
timer_astra.stop("first_fwd")

timer_astra.start("avg_fwd")
for _ in range(num_repeats):
y_astra_from_scico = H_scico @ x
y_astra_from_scico = H_astra_from_scico @ x
jax.block_until_ready(y_astra_from_scico)
timer_astra.stop("avg_fwd")
timer_astra.td["avg_fwd"] /= num_repeats

timer_astra.start("first_back")
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
timer_astra.stop("first_back")

timer_astra.start("avg_back")
for _ in range(num_repeats):
HTy_astra_from_scico = H_astra_from_scico.T @ y_astra_from_scico
jax.block_until_ready(HTy_astra_from_scico)
timer_astra.stop("avg_back")
timer_astra.td["avg_back"] /= num_repeats


"""
Specify geometry with ASTRA conventions and project
Specify geometry with ASTRA conventions and project.
"""

angles = np.linspace(0, np.pi, num_angles) # evenly spaced projection angles
Expand All @@ -126,25 +147,42 @@

y_astra = H_astra @ x

HTy_astra = H_astra.T @ y_astra

"""
Convert ASTRA geometry to SCICO and project
Convert ASTRA geometry to SCICO and project.
"""

P_from_astra = scico.linop.xray.astra_to_scico(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform(Parallel3dProjector(in_shape, P_from_astra, out_shape))

y_scico_from_astra = H_scico_from_astra @ x

HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra

"""
Print timing results.
"""
print(f"init astra {timer_astra.td['init']:.2e} s")
print(f"init scico {timer_scico.td['init']:.2e} s")
print("")
for tstr in ("first", "avg"):
for dstr in ("fwd", "back"):
for timer, pstr in zip((timer_astra, timer_scico), ("astra", "scico")):
print(f"{tstr:5s} {dstr:4s} {pstr} {timer.td[tstr + '_' + dstr]:.2e} s")
print()


"""
Show projections.
"""
fig, ax = plot.subplots(nrows=3, ncols=2, figsize=(8, 6))
plot.imview(y_scico[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico[2], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico[4], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra_from_scico[0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra_from_scico[2], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra_from_scico[4], cbar=None, fig=fig, ax=ax[2, 1])
plot.imview(y_astra_from_scico[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra_from_scico[:, 2], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra_from_scico[:, 4], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()
Expand All @@ -153,12 +191,32 @@
plot.imview(y_scico_from_astra[0], title="SCICO projections", cbar=None, fig=fig, ax=ax[0, 0])
plot.imview(y_scico_from_astra[2], cbar=None, fig=fig, ax=ax[1, 0])
plot.imview(y_scico_from_astra[4], cbar=None, fig=fig, ax=ax[2, 0])
plot.imview(y_astra[0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra[2], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra[4], cbar=None, fig=fig, ax=ax[2, 1])
plot.imview(y_astra[:, 0], title="ASTRA projections", cbar=None, fig=fig, ax=ax[0, 1])
plot.imview(y_astra[:, 2], cbar=None, fig=fig, ax=ax[1, 1])
plot.imview(y_astra[:, 4], cbar=None, fig=fig, ax=ax[2, 1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()

"""
Show back projections.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 6))
plot.imview(HTy_scico[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0])
plot.imview(
HTy_astra_from_scico[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1]
)
fig.suptitle("Using SCICO conventions")
fig.tight_layout()
fig.show()

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(8, 6))
plot.imview(
HTy_scico_from_astra[N // 2], title="SCICO back projection", cbar=None, fig=fig, ax=ax[0]
)
plot.imview(HTy_astra[N // 2], title="ASTRA back projection", cbar=None, fig=fig, ax=ax[1])
fig.suptitle("Using ASTRA conventions")
fig.tight_layout()
fig.show()

input("\nWaiting for input to close figures and exit")
1 change: 1 addition & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Computed Tomography
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py
- ct_projector_comparison_3d.py
- ct_multi_cs_tv_admm.py
- ct_multi_tv_admm.py

Expand Down
8 changes: 4 additions & 4 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def P_to_vectors(in_shape: Shape, P: ArrayLike, det_shape: Shape) -> ArrayLike:
y_center = np.array(det_shape) / 2
x_center = np.einsum("...mn,n->...m", P[..., :3], np.array(in_shape) / 2) + P[..., 3]
d = np.einsum("...mn,...m->...n", P[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2)
u = P[:, 1, :3]
v = P[:, 0, :3]
u = -P[:, 1, :3]
v = -P[:, 0, :3]
vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12)
return vectors

Expand All @@ -295,11 +295,11 @@ def astra_to_scico(vol_geom, proj_geom):
"""
Convert ASTRA volume and projection geometry into a SCICO X-ray projection matrix.
"""
in_shape = (vol_geom["GridColCount"], vol_geom["GridRowCount"], vol_geom["GridSliceCount"])
in_shape = (vol_geom["GridSliceCount"], vol_geom["GridRowCount"], vol_geom["GridColCount"])
det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"])
vectors = proj_geom["Vectors"]
_, d, u, v = vectors[:, 0:3], vectors[:, 3:6], vectors[:, 6:9], vectors[:, 9:12]
P = np.stack((v, u), axis=1)
P = -np.stack((v, u), axis=1)
center_diff = np.einsum("...mn,...n->...m", P, d) # y_center - x_center
y_center = np.array(det_shape) / 2
Px_center_t = -(center_diff - y_center)
Expand Down

0 comments on commit 997ebf3

Please sign in to comment.