Skip to content

Commit

Permalink
Refactor, improve docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Jun 14, 2024
1 parent 99af54a commit c80f3b8
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 80 deletions.
33 changes: 15 additions & 18 deletions examples/scripts/ct_projector_comparison_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@
from scico import plot
from scico.examples import create_block_phantom
from scico.linop import Parallel3dProjector, XRayTransform
from scico.util import Timer
from scipy.spatial.transform import Rotation
from scico.util import ContextTimer, Timer

"""
Create a ground truth image and set detector dimensions.
"""
N = 64
# use rectangular volume to check whether it is handled correctly
# use rectangular volume to check whether axes are handled correctly
in_shape = (N + 1, N + 2, N + 3)
x = create_block_phantom(in_shape)
x = jnp.array(x)

# use rectangular detector to check whether it is handled correctly
# use rectangular detector to check whether axes are handled correctly
out_shape = (N, N + 1)


Expand All @@ -44,17 +43,13 @@
"""
num_angles = 3

# make projection matrix: form a rotation matrix and chop off the last row
rot_X = 90.0 - 16.0
rot_Y = np.random.rand(num_angles) * 180
P = jnp.stack([Rotation.from_euler("XY", [rot_X, y], degrees=True).as_matrix() for y in rot_Y])
P = P[:, :2, :]

# add translation
x0 = jnp.array(in_shape) / 2
t = -jnp.tensordot(P, x0, axes=[2, 0]) + jnp.array(out_shape) / 2
P = jnp.concatenate((P, t[..., np.newaxis]), axis=2)

rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)
matrices = Parallel3dProjector.matrices_from_euler_angles(
in_shape, out_shape, "XY", angles, degrees=True
)

"""
Specify geometry using SCICO conventions and project.
Expand All @@ -63,7 +58,7 @@

timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
H_scico = XRayTransform(Parallel3dProjector(in_shape, P, out_shape))
H_scico = XRayTransform(Parallel3dProjector(in_shape, matrices, out_shape))

with ContextTimer(timer_scico, "first_fwd"):
y_scico = H_scico @ x
Expand Down Expand Up @@ -92,12 +87,14 @@
Convert SCICO geometry to ASTRA and project.
"""

P_to_astra_vectors = scico.linop.xray.P_to_vectors(in_shape, P, out_shape)
vectors_from_scico = scico.linop.xray.astra.convert_from_scico_geometry(
in_shape, matrices, out_shape
)

timer_astra = Timer()
with ContextTimer(timer_astra, "init"):
H_astra_from_scico = astra.XRayTransform3D(
input_shape=in_shape, det_count=out_shape, vectors=P_to_astra_vectors
input_shape=in_shape, det_count=out_shape, vectors=vectors_from_scico
)

with ContextTimer(timer_astra, "first_fwd"):
Expand Down Expand Up @@ -141,7 +138,7 @@
Convert ASTRA geometry to SCICO and project.
"""

P_from_astra = scico.linop.xray.astra_to_scico(H_astra.vol_geom, H_astra.proj_geom)
P_from_astra = scico.linop.xray.astra.convert_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform(Parallel3dProjector(in_shape, P_from_astra, out_shape))

y_scico_from_astra = H_scico_from_astra @ x
Expand Down
12 changes: 2 additions & 10 deletions scico/linop/xray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,10 @@

import sys

from ._xray import (
P_to_vectors,
Parallel2dProjector,
Parallel3dProjector,
XRayTransform,
astra_to_scico,
)
from ._xray import Parallel2dProjector, Parallel3dProjector, XRayTransform

__all__ = [
"XRayTransform",
"Parallel2dProjector",
"Parallel3dProjector",
"P_to_vectors",
"astra_to_scico",
"XRayTransform",
]
91 changes: 39 additions & 52 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jax.typing import ArrayLike

from scico.typing import Shape
from scipy.spatial.transform import Rotation

from .._linop import LinearOperator

Expand Down Expand Up @@ -202,40 +203,40 @@ class Parallel3dProjector:
def __init__(
self,
input_shape: Shape,
P: ArrayLike,
matrices: ArrayLike,
det_shape: Shape,
):
r"""
Args:
input_shape: Shape of input image.
P: (num_angles, 2, 4) array of homogeneous projection matrices.
matrices: (num_angles, 2, 4) array of homogeneous projection matrices.
det_shape: Shape of detector.
"""

self.input_shape = input_shape
self.P = P
self.matrices = matrices
self.det_shape = det_shape
self.output_shape = (len(P), *det_shape)
self.output_shape = (len(matrices), *det_shape)

def project(self, im):
"""Compute X-ray projection."""
return Parallel3dProjector._project(im, self.P, self.det_shape)
return Parallel3dProjector._project(im, self.matrices, self.det_shape)

@staticmethod
@partial(jax.jit, static_argnames="det_shape")
def _project(im: ArrayLike, P: ArrayLike, det_shape: Shape) -> ArrayLike:
def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
r"""
Args:
im: Input image.
P: (num_angles, 2, 4) array of homogeneous projection matrices.
matrices: (num_angles, 2, 4) array of homogeneous projection matrices.
det_shape: Shape of detector.
"""

x = jnp.mgrid[: im.shape[0], : im.shape[1], : im.shape[2]]
# (v, 2, 3) X (3, x0, x1, x2) + (v, 2) -> (v, 2, x0, x1, x2)
Px = (
jnp.tensordot(P[..., :3], x, axes=[2, 0])
+ P[..., 3, np.newaxis, np.newaxis, np.newaxis]
jnp.tensordot(matrices[..., :3], x, axes=[2, 0])
+ matrices[..., 3, np.newaxis, np.newaxis, np.newaxis]
)

# calculate weight on 4 intersecting pixels
Expand All @@ -250,7 +251,7 @@ def _project(im: ArrayLike, P: ArrayLike, det_shape: Shape) -> ArrayLike:
ll_weight = to_next[:, 0] * (w - to_next[:, 1]) * (1 / w**2)
lr_weight = (w - to_next[:, 0]) * (w - to_next[:, 1]) * (1 / w**2)

num_views = len(P)
num_views = len(matrices)
proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype)
view_ind = jnp.expand_dims(jnp.arange(num_views), range(1, 4))
proj = proj.at[view_ind, ul_ind[:, 0], ul_ind[:, 1]].add(ul_weight * im, mode="drop")
Expand All @@ -261,49 +262,35 @@ def _project(im: ArrayLike, P: ArrayLike, det_shape: Shape) -> ArrayLike:
)
return proj

@staticmethod
def matrices_from_euler_angles(
input_shape: Shape, output_shape: Shape, seq: str, angles: ArrayLike, degrees: bool = False
):
"""
Create a set of projection matrices from Euler angles.
def P_to_vectors(in_shape: Shape, P: ArrayLike, det_shape: Shape) -> ArrayLike:
"""
Convert SCICO projection matrix into ASTRA vectors.
For 3D arrays,
in Astra, the dimensions go (slices, rows, columns) and (z, y, x);
in SCICO, the dimensions go (x, y, z).
Args:
input_shape: Shape of input image.
output_shape: Shape of output (detector).
str: Sequence of axes for rotation. Up to 3 characters belonging to the set {'X', 'Y', 'Z'}
for intrinsic rotations, or {'x', 'y', 'z'} for extrinsic rotations. Extrinsic and
intrinsic rotations cannot be mixed in one function call.
angles: (num_angles, N), N = 1, 2, or 3 Euler angles.
degrees: If ``True``, angles are in degrees, otherwise radians. Default: ``True``, radians.
Returns:
(num_angles, 2, 4) array of homogeneous projection matrices.
"""

For 2D arrays,
in Astra, the dimensions go (rows, columns) and (y, x);
in SCICO, the dimensions go (x, y).
# make projection matrix: form a rotation matrix and chop off the last row
matrices = jnp.stack(
[Rotation.from_euler(seq, angles_i, degrees=degrees).as_matrix() for angles_i in angles]
)
matrices = matrices[:, :2, :]

In Astra, the x-grid (recon) is centered on the origin and the y-grid (projection) can move.
In SCICO, the x-grid origin is x[0, 0, 0], the y-grid origin is y[0, 0].
# add translation
x0 = jnp.array(input_shape) / 2
t = -jnp.tensordot(matrices, x0, axes=[2, 0]) + jnp.array(output_shape) / 2
matrices = jnp.concatenate((matrices, t[..., np.newaxis]), axis=2)

See https://astra-toolbox.com/docs/geom3d.html#projection-geometries parallel3d_vec.
"""
# ray is perpendicular to projection axes
ray = np.cross(P[:, 0, :3], P[:, 1, :3])
# detector center comes from lifting the center index to 3D
y_center = np.array(det_shape) / 2
x_center = np.einsum("...mn,n->...m", P[..., :3], np.array(in_shape) / 2) + P[..., 3]
d = np.einsum("...mn,...m->...n", P[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2)
u = -P[:, 1, :3]
v = -P[:, 0, :3]
vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12)
return vectors


def astra_to_scico(vol_geom, proj_geom):
"""
Convert ASTRA volume and projection geometry into a SCICO X-ray projection matrix.
"""
in_shape = (vol_geom["GridSliceCount"], vol_geom["GridRowCount"], vol_geom["GridColCount"])
det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"])
vectors = proj_geom["Vectors"]
_, d, u, v = vectors[:, 0:3], vectors[:, 3:6], vectors[:, 6:9], vectors[:, 9:12]
P = -np.stack((v, u), axis=1)
center_diff = np.einsum("...mn,...n->...m", P, d) # y_center - x_center
y_center = np.array(det_shape) / 2
Px_center_t = -(center_diff - y_center)
Px_center = np.einsum("...mn,n->...m", P, np.array(in_shape) / 2)
t = Px_center_t - Px_center
P = np.concatenate((P, t[..., np.newaxis]), axis=2)
return P
return matrices
68 changes: 68 additions & 0 deletions scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np

import jax
from jax.typing import ArrayLike

try:
import astra
Expand Down Expand Up @@ -54,6 +55,73 @@ def set_astra_gpu_index(idx: Union[int, Sequence[int]]):
astra.set_gpu_index(idx)


def convert_from_scico_geometry(
in_shape: Shape, matrices: ArrayLike, det_shape: Shape
) -> ArrayLike:
"""
Convert SCICO projection matrices into ASTRA "parallel3d_vec" vectors.
For 3D arrays,
in Astra, the dimensions go (slices, rows, columns) and (z, y, x);
in SCICO, the dimensions go (x, y, z).
In Astra, the x-grid (recon) is centered on the origin and the y-grid (projection) can move.
In SCICO, the x-grid origin is the center of x[0, 0, 0], the y-grid origin is the center
of y[0, 0].
See https://astra-toolbox.com/docs/geom3d.html#projection-geometries parallel3d_vec.
Args:
in_shape: Shape of input image.
matrices: (num_angles, 2, 4) array of homogeneous projection matrices.
det_shape: Shape of detector.
Returns:
(num_angles, 12) vector array in the ASTRA "parallel3d_vec" convention.
"""
# ray is perpendicular to projection axes
ray = np.cross(matrices[:, 0, :3], matrices[:, 1, :3])
# detector center comes from lifting the center index to 3D
y_center = np.array(det_shape) / 2
x_center = (
np.einsum("...mn,n->...m", matrices[..., :3], np.array(in_shape) / 2) + matrices[..., 3]
)
d = np.einsum("...mn,...m->...n", matrices[..., :3], y_center - x_center) # (V, 2, 3) x (V, 2)
u = -matrices[:, 1, :3]
v = -matrices[:, 0, :3]
vectors = np.concatenate((ray, d, u, v), axis=1) # (v, 12)
return vectors


def convert_to_scico_geometry(vol_geom, proj_geom):
"""
Convert ASTRA volume and projection geometry into a SCICO X-ray projection matrix, assuming
"parallel3d_vec" format.
Args:
vol_geom: ASTRA volume geometry object.
proj_geom: ASTRA projection geometry object.
Returns:
(num_angles, 2, 4) array of homogeneous projection matrices.
"""
in_shape = (vol_geom["GridSliceCount"], vol_geom["GridRowCount"], vol_geom["GridColCount"])
det_shape = (proj_geom["DetectorRowCount"], proj_geom["DetectorColCount"])
vectors = proj_geom["Vectors"]
_, d, u, v = vectors[:, 0:3], vectors[:, 3:6], vectors[:, 6:9], vectors[:, 9:12]
matrices = -np.stack((v, u), axis=1)
center_diff = np.einsum("...mn,...n->...m", matrices, d) # y_center - x_center
y_center = np.array(det_shape) / 2
Px_center_t = -(center_diff - y_center)
Px_center = np.einsum("...mn,n->...m", matrices, np.array(in_shape) / 2)
t = Px_center_t - Px_center
matrices = np.concatenate((matrices, t[..., np.newaxis]), axis=2)

return matrices


class XRayTransform2D(LinearOperator):
r"""2D parallel beam X-ray transform based on the ASTRA toolbox.
Expand Down

0 comments on commit c80f3b8

Please sign in to comment.