Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 2D X-ray projector bugs #537

Merged
merged 5 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions examples/scripts/ct_multi_cs_tv_admm.py
bwohlberg marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,56 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

det_count = N
det_spacing = np.sqrt(2)


"""
Define CT geometry and construct array of (approximately) equivalent projectors.
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
projectors = {
"astra": astra.XRayTransform2D(x_gt.shape, N, 1.0, angles - np.pi / 2.0), # astra
"svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir
"scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico
"astra": astra.XRayTransform2D(
x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0
), # astra
"svmbir": svmbir.XRayTransform(
x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing
), # svmbir
"scico": XRayTransform(
Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing)
), # scico
}


Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute common sinogram using svmbir projector.
"""
A = projectors["svmbir"]
noise = np.random.normal(size=(n_projection, N)).astype(np.float32)
A = projectors["astra"]
noise = np.random.normal(size=(n_projection, det_count)).astype(np.float32)
y = A @ x_gt + 2.0 * noise


"""
Construct initial solution for regularized problem.
"""
x0 = A.fbp(y)


"""
Solve the same problem using the different projectors.
"""
print(f"Solving on {device_info()}")
x_rec, hist = {}, {}
for p in ("astra", "svmbir", "scico"):
for p in projectors.keys():
print(f"\nSolving with {p} projector")

# Set up ADMM solver object.
λ = 2e0 # L1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
λ = 2e1 # L1 norm regularization parameter
ρ = 1e3 # ADMM penalty parameter
maxiter = 100 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
cg_maxiter = 25 # maximum CG iterations per ADMM iteration
cg_maxiter = 50 # maximum CG iterations per ADMM iteration

# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm,
Expand All @@ -81,7 +96,6 @@
g = λ * functional.L21Norm()
A = projectors[p]
f = loss.SquaredL2Loss(y=y, A=A)
x0 = snp.clip(A.T(y), 0, 1.0)

# Set up the solver.
solver = ADMM(
Expand All @@ -98,7 +112,18 @@
# Run the solver.
solver.solve()
hist[p] = solver.itstat_object.history(transpose=True)
x_rec[p] = snp.clip(solver.x, 0, 1.0)
x_rec[p] = solver.x

if p == "scico":
x_rec[p] = x_rec[p] * det_spacing # to match ASTRA's scaling
bwohlberg marked this conversation as resolved.
Show resolved Hide resolved


"""
Compare reconstruction results.
"""
print("Reconstruction SNR:")
for p in projectors.keys():
print(f" {(p + ':'):7s} {metric.snr(x_gt, x_rec[p]):5.2f} dB")


"""
Expand Down Expand Up @@ -153,6 +178,8 @@
fig=fig,
ax=ax[n + 1],
)
for ax in ax:
ax.get_images()[0].set_clim(-0.1, 1.1)
fig.show()


Expand Down
73 changes: 58 additions & 15 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

from functools import partial
from typing import Optional
from warnings import warn

import numpy as np

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from scico.numpy.util import is_scalar_equiv
from scico.typing import Shape

from .._linop import LinearOperator
Expand Down Expand Up @@ -54,9 +56,9 @@ class Parallel2dProjector:
of each pixel to each bin because the integral of the boxcar is
simple.

By requiring the side length of the pixels to be less than or equal
to the bin width (which is assumed to be 1.0), we ensure that each
pixel contributes to at most two bins, which accelerates the
By requiring the width of a projected pixel to be less than or equal
to the bin width (which is defined to be 1.0), we ensure that
each pixel contributes to at most two bins, which accelerates the
accumulation of pixel values into bins (equivalently, makes the
linear operator sparse).

Expand All @@ -82,7 +84,7 @@ def __init__(
corresponds to summing columns, and an angle of pi/4
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `-im_shape / 2`.
default, `(-im_shape / 2, -im_shape / 2)`.
dx: Image pixel side length in x- and y-direction. Should be
<= 1.0 in each dimension. By default, [1.0, 1.0].
y0: Location of the edge of the first detector bin. By
Expand All @@ -94,13 +96,27 @@ def __init__(
self.angles = angles

self.nx = np.array(im_shape)
if dx is None:
dx = np.full((2,), np.sqrt(2) / 2)
if is_scalar_equiv(dx):
dx = dx * np.ones(2)
self.dx = dx

# check projected pixel width assumption
Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles)))
Pdiag1 = np.abs(Pdx[0] + Pdx[1])
Pdiag2 = np.abs(Pdx[0] - Pdx[1])
max_width = np.max(np.maximum(Pdiag1, Pdiag2))

if max_width > 1:
warn(
f"A projected pixel has width {max_width} > 1.0, "
"which will reduce projector accuracy."
)

if x0 is None:
x0 = -self.nx / 2
x0 = -(self.nx * self.dx) / 2
self.x0 = x0
if dx is None:
dx = np.ones(2)
self.dx = dx

if det_count is None:
det_count = int(np.ceil(np.linalg.norm(im_shape)))
Expand All @@ -112,15 +128,14 @@ def __init__(
self.y0 = y0
self.dy = 1.0

if any(self.dx > self.dy):
raise ValueError(
f"This projector assumes dx <= dy, but dx was {self.dx} and dy was {self.dy}."
)

def project(self, im):
"""Compute X-ray projection."""
return _project(im, self.x0, self.dx, self.y0, self.ny, self.angles)

def back_project(self, y):
"""Compute X-ray back projection"""
return _back_project(y, self.x0, self.dx, tuple(self.nx), self.y0, self.angles)


@partial(jax.jit, static_argnames=["ny"])
def _project(im, x0, dx, y0, ny, angles):
Expand All @@ -138,8 +153,8 @@ def _project(im, x0, dx, y0, ny, angles):
nx = im.shape
inds, weights = _calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to y0.
inds = jnp.where(inds > 0, inds, ny)
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

y = (
jnp.zeros((len(angles), ny))
Expand All @@ -152,6 +167,34 @@ def _project(im, x0, dx, y0, ny, angles):
return y


@partial(jax.jit, static_argnames=["nx"])
def _back_project(y, x0, dx, nx, y0, angles):
r"""
Args:
y: Input projection, (num_angles, N).
x0: (x, y) position of the corner of the pixel im[0,0].
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
nx: Shape of back projection.
y0: Location of the edge of the first detector bin.
angles: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
"""
ny = y.shape[1]
inds, weights = _calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# the idea: [y[0, inds[0]], y[1, inds[1]], ...]
HTy = jnp.sum(y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds] * weights, axis=0)
HTy = HTy + jnp.sum(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * (1 - weights), axis=0
)

return HTy


@partial(jax.jit, static_argnames=["nx", "y0"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(x0, dx, nx, angle, y0):
Expand Down
27 changes: 26 additions & 1 deletion scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
import jax.numpy as jnp

import pytest

import scico
from scico.linop import Parallel2dProjector, XRayTransform


@pytest.mark.filterwarnings("error")
def test_init():
input_shape = (3, 3)

# no warning with default settings, even at 45 degrees
H = XRayTransform(Parallel2dProjector(input_shape, jnp.array([jnp.pi / 4])))

# no warning if we project orthogonally with oversized pixels
H = XRayTransform(Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1, 1])))

# warning if the projection angle changes
with pytest.warns(UserWarning):
H = XRayTransform(
Parallel2dProjector(input_shape, jnp.array([0.1]), dx=jnp.array([1.1, 1.1]))
)

# warning if the pixels get any larger
with pytest.warns(UserWarning):
H = XRayTransform(
Parallel2dProjector(input_shape, jnp.array([0]), dx=jnp.array([1.1, 1.1]))
)


def test_apply():
im_shape = (12, 13)
num_angles = 10
Expand Down Expand Up @@ -38,7 +63,7 @@ def test_apply_adjoint():
# adjoint
bp = H.T @ y
assert scico.linop.valid_adjoint(
H, H.T, eps=1e-5
H, H.T, eps=1e-4
) # associative reductions might cause small errors, hence 1e-5

# fixed det_length
Expand Down
Loading