diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 44e8c05b..9efddffa 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -1,5 +1,6 @@ import numpy as np +import jax import jax.numpy as jnp import pytest @@ -76,8 +77,9 @@ def test_apply_adjoint(): @pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0]) def test_fbp(dx, det_count_factor): N = 256 - x_gt = np.zeros((256, 256), dtype=np.float32) - x_gt[64:-64, 64:-64] = 1.0 + x_gt = np.zeros((N, N), dtype=np.float32) + N4 = N // 4 + x_gt[N4:-N4, N4:-N4] = 1.0 det_count = int(det_count_factor * N) n_proj = 360 @@ -88,6 +90,19 @@ def test_fbp(dx, det_count_factor): assert psnr(x_gt, x_fbp) > 28 +def test_fbp_jit(): + N = 64 + x_gt = np.ones((N, N), dtype=np.float32) + + det_count = N + n_proj = 90 + angles = np.linspace(0, np.pi, n_proj, endpoint=False) + A = XRayTransform2D(x_gt.shape, angles, det_count=det_count) + y = A(x_gt) + fbp = jax.jit(A.fbp) + x_fbp = fbp(y) + + def test_3d_scaling(): x = jnp.zeros((4, 4, 1)) x = x.at[1:3, 1:3, 0].set(1.0)