Skip to content

Commit

Permalink
Rename new projectors to XRayTransform2D and XRayTransform3D
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Sep 9, 2024
1 parent 854d3ad commit eee399f
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 49 deletions.
3 changes: 1 addition & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ Version 0.0.6 (unreleased)
----------------------------

• Significant changes to ``linop.xray.astra`` API.
• New integrated 3D X-ray transform via ``linop.XRayTransform`` and
``linop.Parallel3dProjector``.
• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``.
• New functional ``functional.IsotropicTVNorm`` and faster implementation
of ``functional.AnisotropicTVNorm``.
• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``,
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ct_large_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax

from scico.examples import create_block_phantom
from scico.linop import Parallel3dProjector
from scico.linop import XRayTransform3D

N = 1000
num_views = 10
Expand All @@ -31,12 +31,12 @@
rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_views, endpoint=False)
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)
matrices = Parallel3dProjector.matrices_from_euler_angles(
matrices = XRayTransform3D.matrices_from_euler_angles(
in_shape, det_shape, "XY", angles, degrees=True
)


H = Parallel3dProjector(in_shape, matrices, det_shape)
H = XRayTransform3D(in_shape, matrices, det_shape)

proj = H @ x
jax.block_until_ready(proj)
4 changes: 2 additions & 2 deletions examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, astra, svmbir
from scico.linop.xray import XRayTransform2D, astra, svmbir
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -54,7 +54,7 @@
"svmbir": svmbir.XRayTransform(
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing
), # svmbir
"scico": Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico
"scico": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico
}


Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ct_projector_comparison_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import scico.linop.xray.astra as astra
from scico import plot
from scico.linop import Parallel2dProjector
from scico.linop import XRayTransform2D
from scico.util import Timer

"""
Expand All @@ -46,7 +46,7 @@

projectors = {}
timer.start("scico_init")
projectors["scico"] = Parallel2dProjector((N, N), angles)
projectors["scico"] = XRayTransform2D((N, N), angles)
timer.stop("scico_init")

timer.start("astra_init")
Expand Down
8 changes: 4 additions & 4 deletions examples/scripts/ct_projector_comparison_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import scico.linop.xray.astra as astra
from scico import plot
from scico.examples import create_block_phantom
from scico.linop import Parallel3dProjector, XRayTransform
from scico.linop import XRayTransform, XRayTransform3D
from scico.util import ContextTimer, Timer

"""
Expand All @@ -47,7 +47,7 @@
rot_X = 90.0 - 16.0
rot_Y = np.linspace(0, 180, num_angles, endpoint=False)
angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1)
matrices = Parallel3dProjector.matrices_from_euler_angles(
matrices = XRayTransform3D.matrices_from_euler_angles(
in_shape, out_shape, "XY", angles, degrees=True
)

Expand All @@ -58,7 +58,7 @@

timer_scico = Timer()
with ContextTimer(timer_scico, "init"):
H_scico = XRayTransform(Parallel3dProjector(in_shape, matrices, out_shape))
H_scico = XRayTransform(XRayTransform3D(in_shape, matrices, out_shape))

with ContextTimer(timer_scico, "first_fwd"):
y_scico = H_scico @ x
Expand Down Expand Up @@ -133,7 +133,7 @@
"""

P_from_astra = scico.linop.xray.astra.convert_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom)
H_scico_from_astra = XRayTransform(Parallel3dProjector(in_shape, P_from_astra, out_shape))
H_scico_from_astra = XRayTransform(XRayTransform3D(in_shape, P_from_astra, out_shape))

y_scico_from_astra = H_scico_from_astra @ x
HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ct_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector
from scico.linop.xray import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Expand All @@ -46,7 +46,7 @@
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
A = Parallel2dProjector((N, N), angles) # CT projection operator
A = XRayTransform2D((N, N), angles) # CT projection operator
y = A @ x_gt # sinogram


Expand Down
3 changes: 0 additions & 3 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ._matrix import MatrixOperator
from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from .xray import Parallel2dProjector, Parallel3dProjector

__all__ = [
"CircularConvolve",
Expand All @@ -51,8 +50,6 @@
"Sum",
"Transpose",
"LinearOperator",
"Parallel2dProjector",
"Parallel3dProjector",
"ComposedLinearOperator",
"linop_from_function",
"linop_over_axes",
Expand Down
6 changes: 3 additions & 3 deletions scico/linop/xray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
"""


from ._xray import Parallel2dProjector, Parallel3dProjector
from ._xray import XRayTransform2D, XRayTransform3D

__all__ = [
"Parallel2dProjector",
"Parallel3dProjector",
"XRayTransform2D",
"XRayTransform3D",
]
26 changes: 13 additions & 13 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .._linop import LinearOperator


class Parallel2dProjector(LinearOperator):
class XRayTransform2D(LinearOperator):
"""Parallel ray, single axis, 2D X-ray projector.
This implementation approximates the projection of each rectangular
Expand Down Expand Up @@ -117,11 +117,11 @@ def __init__(

def project(self, im):
"""Compute X-ray projection."""
return Parallel2dProjector._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)

def back_project(self, y):
"""Compute X-ray back projection"""
return Parallel2dProjector._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
Expand All @@ -138,7 +138,7 @@ def _project(im, x0, dx, y0, ny, angles):
projected onto unit vectors pointing in these directions.
"""
nx = im.shape
inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0)
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)
Expand Down Expand Up @@ -168,7 +168,7 @@ def _back_project(y, x0, dx, nx, y0, angles):
projected onto units vectors pointing in these directions.
"""
ny = y.shape[1]
inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0)
inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)
Expand Down Expand Up @@ -225,7 +225,7 @@ def _calc_weights(x0, dx, nx, angle, y0):
return inds, weights


class Parallel3dProjector(LinearOperator):
class XRayTransform3D(LinearOperator):
r"""General-purpose, 3D, parallel ray X-ray projector.
For each view, the projection geometry is specified by an array
Expand All @@ -242,7 +242,7 @@ class Parallel3dProjector(LinearOperator):
The detector pixel at index `(i, j)` covers detector coordinates
:math:`[i+1) \times [j+1)`.
:meth:`Parallel3dProjector.matrices_from_euler_angles` can help to
:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
make these geometry arrays.
Expand Down Expand Up @@ -277,11 +277,11 @@ def __init__(

def project(self, im):
"""Compute X-ray projection."""
return Parallel3dProjector._project(im, self.matrices, self.det_shape)
return XRayTransform3D._project(im, self.matrices, self.det_shape)

def back_project(self, proj):
"""Compute X-ray back projection"""
return Parallel3dProjector._back_project(proj, self.matrices, self.input_shape)
return XRayTransform3D._back_project(proj, self.matrices, self.input_shape)

@staticmethod
def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
Expand All @@ -299,7 +299,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
for view_ind, matrix in enumerate(matrices):
for slice_offset in slice_offsets:
proj = proj.at[view_ind].set(
Parallel3dProjector._project_single(
XRayTransform3D._project_single(
im[slice_offset : slice_offset + MAX_SLICE_LEN],
matrix,
proj[view_ind],
Expand All @@ -320,7 +320,7 @@ def _project_single(
det_shape: Shape of detector.
"""

ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights(
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(
im.shape, matrix, proj.shape, slice_offset
)
proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode="drop")
Expand All @@ -344,7 +344,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A
for view_ind, matrix in enumerate(matrices):
for slice_offset in slice_offsets:
HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set(
Parallel3dProjector._back_project_single(
XRayTransform3D._back_project_single(
proj[view_ind],
matrix,
HTy[slice_offset : slice_offset + MAX_SLICE_LEN],
Expand All @@ -360,7 +360,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A
def _back_project_single(
y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0
) -> ArrayLike:
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights(
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights(
HTy.shape, matrix, y.shape, slice_offset
)
HTy = HTy + y[ul_ind[0], ul_ind[1]] * ul_weight
Expand Down
30 changes: 15 additions & 15 deletions scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
import pytest

import scico
from scico.linop import Parallel2dProjector, Parallel3dProjector
from scico.linop.xray import XRayTransform2D, XRayTransform3D


@pytest.mark.filterwarnings("error")
def test_init():
input_shape = (3, 3)

# no warning with default settings, even at 45 degrees
H = Parallel2dProjector(input_shape, jnp.array([jnp.pi / 4]))
H = XRayTransform2D(input_shape, jnp.array([jnp.pi / 4]))

# no warning if we project orthogonally with oversized pixels
H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1, 1]))
H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1, 1]))

# warning if the projection angle changes
with pytest.warns(UserWarning):
H = Parallel2dProjector(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1]))
H = XRayTransform2D(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1]))

# warning if the pixels get any larger
with pytest.warns(UserWarning):
H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1]))
H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1]))


def test_apply():
Expand All @@ -35,13 +35,13 @@ def test_apply():
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

# general projection
H = Parallel2dProjector(x.shape, angles)
H = XRayTransform2D(x.shape, angles)
y = H @ x
assert y.shape[0] == (num_angles)

# fixed det_count
det_count = 14
H = Parallel2dProjector(x.shape, angles, det_count=det_count)
H = XRayTransform2D(x.shape, angles, det_count=det_count)
y = H @ x
assert y.shape[1] == det_count

Expand All @@ -54,7 +54,7 @@ def test_apply_adjoint():
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

# general projection
H = Parallel2dProjector(x.shape, angles)
H = XRayTransform2D(x.shape, angles)
y = H @ x
assert y.shape[0] == (num_angles)

Expand All @@ -66,7 +66,7 @@ def test_apply_adjoint():

# fixed det_length
det_count = 14
H = Parallel2dProjector(x.shape, angles, det_count=det_count)
H = XRayTransform2D(x.shape, angles, det_count=det_count)
y = H @ x
assert y.shape[1] == det_count

Expand All @@ -79,8 +79,8 @@ def test_3d_scaling():
output_shape = x.shape[:2]

# default spacing
M = Parallel3dProjector.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape)
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
Expand All @@ -91,10 +91,10 @@ def test_3d_scaling():
np.testing.assert_allclose(H @ x, truth)

# bigger voxels in the x (first index) direction
M = Parallel3dProjector.matrices_from_euler_angles(
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0]
)
H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = jnp.array(
[[[0. , 0.5, 0.5, 0. ],
Expand All @@ -105,10 +105,10 @@ def test_3d_scaling():
np.testing.assert_allclose(H @ x, truth)

# bigger detector pixels in the x (first index) direction
M = Parallel3dProjector.matrices_from_euler_angles(
M = XRayTransform3D.matrices_from_euler_angles(
input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0]
)
H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape)
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
# fmt: off
truth = None # fmt: on # TODO: Check this case more closely.
# np.testing.assert_allclose(H @ x, truth)

0 comments on commit eee399f

Please sign in to comment.