Skip to content

Commit

Permalink
Add manual back projector
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Jul 22, 2024
1 parent b95e398 commit c8fcbfe
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,6 @@ def project(self, im):
"""Compute X-ray projection."""
return Parallel3dProjector._project(im, self.matrices, self.det_shape)

def back_project(self, proj):
"""Compute X-ray back projection"""

@staticmethod
def _project(im: ArrayLike, matrices: ArrayLike, det_shape: Shape) -> ArrayLike:
r"""
Expand Down Expand Up @@ -305,7 +302,7 @@ def _project_single(
det_shape: Shape of detector.
"""

ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = _calc_weights(
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights(
im.shape, matrix, proj.shape, slice_offset
)
proj = proj.at[ul_ind[0], ul_ind[1]].add(ul_weight * im, mode="drop")
Expand All @@ -314,6 +311,48 @@ def _project_single(
proj = proj.at[ul_ind[0] + 1, ul_ind[1] + 1].add(lr_weight * im, mode="drop")
return proj

def back_project(self, proj):
"""Compute X-ray back projection"""
return Parallel3dProjector._back_project(proj, self.matrices, self.input_shape)

@staticmethod
def _back_project(proj: ArrayLike, matrices: ArrayLike, input_shape: Shape) -> ArrayLike:
r"""
Args:
proj: Input (set of) projection(s).
matrix: (num_views, 2, 4) array of homogeneous projection matrices.
input_shape: Shape of desired back projection.
"""
MAX_SLICE_LEN = 10
slice_offsets = list(range(0, input_shape[0], MAX_SLICE_LEN))

HTy = jnp.zeros(input_shape, dtype=proj.dtype)
for view_ind, matrix in enumerate(matrices):
for slice_offset in slice_offsets:
HTy = HTy.at[slice_offset : slice_offset + MAX_SLICE_LEN].set(
Parallel3dProjector._back_project_single(
proj[view_ind],
matrix,
HTy[slice_offset : slice_offset + MAX_SLICE_LEN],
slice_offset=slice_offset,
)
)
return HTy

@staticmethod
@partial(jax.jit, donate_argnames="HTy")
def _back_project_single(
y: ArrayLike, matrix: ArrayLike, HTy: ArrayLike, slice_offset: int = 0
) -> ArrayLike:
ul_ind, ul_weight, ur_weight, ll_weight, lr_weight = Parallel3dProjector._calc_weights(
HTy.shape, matrix, y.shape, slice_offset
)
HTy = HTy + y[ul_ind[0], ul_ind[1]] * ul_weight
HTy = HTy + y[ul_ind[0] + 1, ul_ind[1]] * ur_weight
HTy = HTy + y[ul_ind[0], ul_ind[1] + 1] * ll_weight
HTy = HTy + y[ul_ind[0] + 1, ul_ind[1] + 1] * lr_weight
return HTy

@staticmethod
def _calc_weights(input_shape, matrix, output_shape, slice_offset: int = 0):
# pixel (0, 0, 0) has its center at (0.5, 0.5, 0.5)
Expand Down

0 comments on commit c8fcbfe

Please sign in to comment.