diff --git a/CHANGES.rst b/CHANGES.rst index a32c3854..8a786484 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,8 +7,7 @@ Version 0.0.6 (unreleased) ---------------------------- • Significant changes to ``linop.xray.astra`` API. -• New integrated 3D X-ray transform via ``linop.XRayTransform`` and - ``linop.Parallel3dProjector``. +• New integrated 3D X-ray transform via ``linop.xray.XRayTransform3D``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. • New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``, diff --git a/examples/scripts/ct_large_projection.py b/examples/scripts/ct_large_projection.py index 9281b68e..814ab7cb 100644 --- a/examples/scripts/ct_large_projection.py +++ b/examples/scripts/ct_large_projection.py @@ -18,7 +18,7 @@ import jax from scico.examples import create_block_phantom -from scico.linop import Parallel3dProjector +from scico.linop import XRayTransform3D N = 1000 num_views = 10 @@ -31,12 +31,12 @@ rot_X = 90.0 - 16.0 rot_Y = np.linspace(0, 180, num_views, endpoint=False) angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1) -matrices = Parallel3dProjector.matrices_from_euler_angles( +matrices = XRayTransform3D.matrices_from_euler_angles( in_shape, det_shape, "XY", angles, degrees=True ) -H = Parallel3dProjector(in_shape, matrices, det_shape) +H = XRayTransform3D(in_shape, matrices, det_shape) proj = H @ x jax.block_until_ready(proj) diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 169b9eba..95711679 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -27,7 +27,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.xray import Parallel2dProjector, astra, svmbir +from scico.linop.xray import XRayTransform2D, astra, svmbir from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -54,7 +54,7 @@ "svmbir": svmbir.XRayTransform( x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing ), # svmbir - "scico": Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico + "scico": XRayTransform2D((N, N), angles, det_count=det_count, dx=1 / det_spacing), # scico } diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py index 19352626..b54810c5 100644 --- a/examples/scripts/ct_projector_comparison_2d.py +++ b/examples/scripts/ct_projector_comparison_2d.py @@ -22,7 +22,7 @@ import scico.linop.xray.astra as astra from scico import plot -from scico.linop import Parallel2dProjector +from scico.linop import XRayTransform2D from scico.util import Timer """ @@ -46,7 +46,7 @@ projectors = {} timer.start("scico_init") -projectors["scico"] = Parallel2dProjector((N, N), angles) +projectors["scico"] = XRayTransform2D((N, N), angles) timer.stop("scico_init") timer.start("astra_init") diff --git a/examples/scripts/ct_projector_comparison_3d.py b/examples/scripts/ct_projector_comparison_3d.py index 6fd36401..2752b9b9 100644 --- a/examples/scripts/ct_projector_comparison_3d.py +++ b/examples/scripts/ct_projector_comparison_3d.py @@ -22,7 +22,7 @@ import scico.linop.xray.astra as astra from scico import plot from scico.examples import create_block_phantom -from scico.linop import Parallel3dProjector, XRayTransform +from scico.linop import XRayTransform, XRayTransform3D from scico.util import ContextTimer, Timer """ @@ -47,7 +47,7 @@ rot_X = 90.0 - 16.0 rot_Y = np.linspace(0, 180, num_angles, endpoint=False) angles = np.stack(np.broadcast_arrays(rot_X, rot_Y), axis=-1) -matrices = Parallel3dProjector.matrices_from_euler_angles( +matrices = XRayTransform3D.matrices_from_euler_angles( in_shape, out_shape, "XY", angles, degrees=True ) @@ -58,7 +58,7 @@ timer_scico = Timer() with ContextTimer(timer_scico, "init"): - H_scico = XRayTransform(Parallel3dProjector(in_shape, matrices, out_shape)) + H_scico = XRayTransform(XRayTransform3D(in_shape, matrices, out_shape)) with ContextTimer(timer_scico, "first_fwd"): y_scico = H_scico @ x @@ -133,7 +133,7 @@ """ P_from_astra = scico.linop.xray.astra.convert_to_scico_geometry(H_astra.vol_geom, H_astra.proj_geom) -H_scico_from_astra = XRayTransform(Parallel3dProjector(in_shape, P_from_astra, out_shape)) +H_scico_from_astra = XRayTransform(XRayTransform3D(in_shape, P_from_astra, out_shape)) y_scico_from_astra = H_scico_from_astra @ x HTy_scico_from_astra = H_scico_from_astra.T @ y_scico_from_astra diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py index da80e180..ec48d4ea 100644 --- a/examples/scripts/ct_tv_admm.py +++ b/examples/scripts/ct_tv_admm.py @@ -29,7 +29,7 @@ import scico.numpy as snp from scico import functional, linop, loss, metric, plot -from scico.linop.xray import Parallel2dProjector +from scico.linop.xray import XRayTransform2D from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -46,7 +46,7 @@ """ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles -A = Parallel2dProjector((N, N), angles) # CT projection operator +A = XRayTransform2D((N, N), angles) # CT projection operator y = A @ x_gt # sinogram diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 8c465c81..2d6f5cd5 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -25,7 +25,6 @@ from ._matrix import MatrixOperator from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes from ._util import jacobian, operator_norm, power_iteration, valid_adjoint -from .xray import Parallel2dProjector, Parallel3dProjector __all__ = [ "CircularConvolve", @@ -51,8 +50,6 @@ "Sum", "Transpose", "LinearOperator", - "Parallel2dProjector", - "Parallel3dProjector", "ComposedLinearOperator", "linop_from_function", "linop_over_axes", diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py index f00a0a5e..4b57161b 100644 --- a/scico/linop/xray/__init__.py +++ b/scico/linop/xray/__init__.py @@ -46,9 +46,9 @@ """ -from ._xray import Parallel2dProjector, Parallel3dProjector +from ._xray import XRayTransform2D, XRayTransform3D __all__ = [ - "Parallel2dProjector", - "Parallel3dProjector", + "XRayTransform2D", + "XRayTransform3D", ] diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index e58ffd7f..3ccd5d96 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -25,7 +25,7 @@ from .._linop import LinearOperator -class Parallel2dProjector(LinearOperator): +class XRayTransform2D(LinearOperator): """Parallel ray, single axis, 2D X-ray projector. This implementation approximates the projection of each rectangular @@ -117,11 +117,11 @@ def __init__( def project(self, im): """Compute X-ray projection.""" - return Parallel2dProjector._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) + return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles) def back_project(self, y): """Compute X-ray back projection""" - return Parallel2dProjector._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) + return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles) @staticmethod @partial(jax.jit, static_argnames=["ny"]) @@ -138,7 +138,7 @@ def _project(im, x0, dx, y0, ny, angles): projected onto unit vectors pointing in these directions. """ nx = im.shape - inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0) + inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) # Handle out of bounds indices. In the .at call, inds >= y0 are # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. inds = jnp.where(inds >= 0, inds, ny) @@ -168,7 +168,7 @@ def _back_project(y, x0, dx, nx, y0, angles): projected onto units vectors pointing in these directions. """ ny = y.shape[1] - inds, weights = Parallel2dProjector._calc_weights(x0, dx, nx, angles, y0) + inds, weights = XRayTransform2D._calc_weights(x0, dx, nx, angles, y0) # Handle out of bounds indices. In the .at call, inds >= y0 are # ignored, while inds < 0 wrap around. So we set inds < 0 to ny. inds = jnp.where(inds >= 0, inds, ny) @@ -225,7 +225,7 @@ def _calc_weights(x0, dx, nx, angle, y0): return inds, weights -class Parallel3dProjector(LinearOperator): +class XRayTransform3D(LinearOperator): r"""General-purpose, 3D, parallel ray X-ray projector. For each view, the projection geometry is specified by an array @@ -242,7 +242,7 @@ class Parallel3dProjector(LinearOperator): The detector pixel at index `(i, j)` covers detector coordinates :math:`[i+1) \times [j+1)`. - :meth:`Parallel3dProjector.matrices_from_euler_angles` can help to + :meth:`XRayTransform3D.matrices_from_euler_angles` can help to make these geometry arrays. @@ -277,11 +277,11 @@ def __init__( def project(self, im): """Compute X-ray projection.""" - return Parallel3dProjector._project(im, self.matrices, self.det_shape) + return XRayTransform3D._project(im, self.matrices, self.det_shape) def back_project(self, proj): """Compute X-ray back projection""" - return Parallel3dProjector._back_project(proj, self.matrices, self.input_shape) + return XRayTransform3D._back_project(proj, self.matrices, self.input_shape) @staticmethod def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: @@ -299,7 +299,7 @@ def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike: for view_ind, matrix in enumerate(matrices): for slice_offset in slice_offsets: proj = proj.at[view_ind].set( - Parallel3dProjector._project_single( + XRayTransform3D._project_single( im[slice_offset : slice_offset + MAX_SLICE_LEN], matrix, proj[view_ind], @@ -320,7 +320,7 @@ def _project_single( det_shape: Shape of detector. """ - ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights( + ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights( im.shape, matrix, proj.shape, slice_offset ) proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode="drop") @@ -344,7 +344,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A for view_ind, matrix in enumerate(matrices): for slice_offset in slice_offsets: HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set( - Parallel3dProjector._back_project_single( + XRayTransform3D._back_project_single( proj[view_ind], matrix, HTy[slice_offset : slice_offset + MAX_SLICE_LEN], @@ -360,7 +360,7 @@ def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> A def _back_project_single( y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0 ) -> ArrayLike: - ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights( + ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = XRayTransform3D._calc_weights( HTy.shape, matrix, y.shape, slice_offset ) HTy = HTy + y[ul_ind[0], ul_ind[1]] * ul_weight diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 5adbd84e..cd7c0dcd 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -5,7 +5,7 @@ import pytest import scico -from scico.linop import Parallel2dProjector, Parallel3dProjector +from scico.linop.xray import XRayTransform2D, XRayTransform3D @pytest.mark.filterwarnings("error") @@ -13,18 +13,18 @@ def test_init(): input_shape = (3, 3) # no warning with default settings, even at 45 degrees - H = Parallel2dProjector(input_shape, jnp.array([jnp.pi / 4])) + H = XRayTransform2D(input_shape, jnp.array([jnp.pi / 4])) # no warning if we project orthogonally with oversized pixels - H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1, 1])) + H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1, 1])) # warning if the projection angle changes with pytest.warns(UserWarning): - H = Parallel2dProjector(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1])) + H = XRayTransform2D(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1])) # warning if the pixels get any larger with pytest.warns(UserWarning): - H = Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1])) + H = XRayTransform2D(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1])) def test_apply(): @@ -35,13 +35,13 @@ def test_apply(): angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection - H = Parallel2dProjector(x.shape, angles) + H = XRayTransform2D(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) # fixed det_count det_count = 14 - H = Parallel2dProjector(x.shape, angles, det_count=det_count) + H = XRayTransform2D(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count @@ -54,7 +54,7 @@ def test_apply_adjoint(): angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) # general projection - H = Parallel2dProjector(x.shape, angles) + H = XRayTransform2D(x.shape, angles) y = H @ x assert y.shape[0] == (num_angles) @@ -66,7 +66,7 @@ def test_apply_adjoint(): # fixed det_length det_count = 14 - H = Parallel2dProjector(x.shape, angles, det_count=det_count) + H = XRayTransform2D(x.shape, angles, det_count=det_count) y = H @ x assert y.shape[1] == det_count @@ -79,8 +79,8 @@ def test_3d_scaling(): output_shape = x.shape[:2] # default spacing - M = Parallel3dProjector.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) - H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) + H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) # fmt: off truth = jnp.array( [[[0.0, 0.0, 0.0, 0.0], @@ -91,10 +91,10 @@ def test_3d_scaling(): np.testing.assert_allclose(H @ x, truth) # bigger voxels in the x (first index) direction - M = Parallel3dProjector.matrices_from_euler_angles( + M = XRayTransform3D.matrices_from_euler_angles( input_shape, output_shape, "X", [0.0], voxel_spacing=[2.0, 1.0, 1.0] ) - H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) # fmt: off truth = jnp.array( [[[0. , 0.5, 0.5, 0. ], @@ -105,10 +105,10 @@ def test_3d_scaling(): np.testing.assert_allclose(H @ x, truth) # bigger detector pixels in the x (first index) direction - M = Parallel3dProjector.matrices_from_euler_angles( + M = XRayTransform3D.matrices_from_euler_angles( input_shape, output_shape, "X", [0.0], det_spacing=[2.0, 1.0] ) - H = Parallel3dProjector(input_shape, matrices=M, det_shape=output_shape) + H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) # fmt: off truth = None # fmt: on # TODO: Check this case more closely. # np.testing.assert_allclose(H @ x, truth)