Skip to content

Commit

Permalink
Replace jax.experimental.host_callback.call with jax.pure_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Dec 15, 2023
1 parent c170159 commit 9e7de57
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 27 deletions.
18 changes: 13 additions & 5 deletions scico/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np

from jax.experimental import host_callback as hcb
import jax

try:
import bm3d as tubm3d
Expand Down Expand Up @@ -84,7 +84,7 @@ def bm3d_eval(x: snp.Array, sigma: float):
"BM3D requires two-dimensional or three dimensional inputs; got ndim = {x.ndim}."
)

# This check is also performed inside the BM3D call, but due to the host_callback,
# This check is also performed inside the BM3D call, but due to the callback,
# no exception is raised and the program will crash with no traceback.
# NOTE: if BM3D is extended to allow for different profiles, the block size must be
# updated; this presumes 'np' profile (bs=8)
Expand All @@ -103,7 +103,11 @@ def bm3d_eval(x: snp.Array, sigma: float):
" the additional axes are singletons."
)

y = hcb.call(lambda args: bm3d_eval(*args).astype(x.dtype), (x, sigma), result_shape=x)
y = jax.pure_callback(
lambda args: bm3d_eval(*args).astype(x.dtype),
jax.ShapeDtypeStruct(x.shape, x.dtype),
(x, sigma),
)

# undo squeezing, if neccessary
y = y.reshape(x_in_shape)
Expand Down Expand Up @@ -145,7 +149,7 @@ def bm4d_eval(x: snp.Array, sigma: float):
if isinstance(x.ndim, tuple) or x.ndim < 3:
raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}.")

# This check is also performed inside the BM4D call, but due to the host_callback,
# This check is also performed inside the BM4D call, but due to the callback,
# no exception is raised and the program will crash with no traceback.
# NOTE: if BM4D is extended to allow for different profiles, the block size must be
# updated; this presumes 'np' profile (bs=8)
Expand All @@ -164,7 +168,11 @@ def bm4d_eval(x: snp.Array, sigma: float):
" the additional axes are singletons."
)

y = hcb.call(lambda args: bm4d_eval(*args).astype(x.dtype), (x, sigma), result_shape=x)
y = jax.pure_callback(
lambda args: bm4d_eval(*args).astype(x.dtype),
jax.ShapeDtypeStruct(x.shape, x.dtype),
(x, sigma),
)

# undo squeezing, if neccessary
y = y.reshape(x_in_shape)
Expand Down
10 changes: 5 additions & 5 deletions scico/flax/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Arr
version constructed to be differentiable with the autograd
functionality from jax. Therefore, (i) it uses :meth:`jax.lax.scan`
to execute a fixed number of iterations and (ii) it assumes that the
linear operator may use :meth:`jax.experimental.host_callback`. Due
to the utilization of a while cycle, :meth:`scico.cg` is not
differentiable by jax and :meth:`jax.scipy.sparse.linalg.cg` does not
support functions using :meth:`jax.experimental.host_callback`
explaining why an additional conjugate gradient function is implemented.
linear operator may use :meth:`jax.pure_callback`. Due to the
utilization of a while cycle, :meth:`scico.cg` is not differentiable
by jax and :meth:`jax.scipy.sparse.linalg.cg` does not support
functions using :meth:`jax.pure_callback`, which is why an additional
conjugate gradient function has been implemented.
Args:
A: Function implementing linear operator :math:`A`, should be
Expand Down
11 changes: 3 additions & 8 deletions scico/linop/xray/astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np

import jax
import jax.experimental.host_callback as hcb

try:
import astra
Expand Down Expand Up @@ -196,9 +195,7 @@ def f(x):
astra.data3d.delete(proj_id)
return result

return hcb.call(
f, x, result_shape=jax.ShapeDtypeStruct(self.output_shape, self.output_dtype)
)
return jax.pure_callback(f, jax.ShapeDtypeStruct(self.output_shape, self.output_dtype), x)

def _bproj(self, y: jax.Array) -> jax.Array:
# applies backprojector
Expand All @@ -215,7 +212,7 @@ def f(y):
astra.data3d.delete(proj_id)
return result

return hcb.call(f, y, result_shape=jax.ShapeDtypeStruct(self.input_shape, self.input_dtype))
return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), y)

def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array:
"""Filtered back projection (FBP) reconstruction.
Expand Down Expand Up @@ -262,6 +259,4 @@ def f(sino):
astra.data2d.delete(sino_id)
return out

return hcb.call(
f, sino, result_shape=jax.ShapeDtypeStruct(self.input_shape, self.input_dtype)
)
return jax.pure_callback(f, jax.ShapeDtypeStruct(self.input_shape, self.input_dtype), sino)
15 changes: 6 additions & 9 deletions scico/linop/xray/svmbir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np

import jax
import jax.experimental.host_callback

import scico.numpy as snp
from scico.loss import Loss, SquaredL2Loss
Expand Down Expand Up @@ -175,7 +174,6 @@ def __init__(
self.delta_pixel = delta_pixel

elif self.geometry == "parallel":

self.magnification = 1.0
if delta_pixel is None:
self.delta_pixel = self.delta_channel
Expand Down Expand Up @@ -232,8 +230,8 @@ def _proj(

def _proj_hcb(self, x):
x = x.reshape(self.svmbir_input_shape)
# host callback wrapper for _proj
y = jax.experimental.host_callback.call(
# callback wrapper for _proj
y = jax.pure_callback(
lambda x: self._proj(
x,
self.angles,
Expand All @@ -246,8 +244,8 @@ def _proj_hcb(self, x):
delta_channel=self.delta_channel,
delta_pixel=self.delta_pixel,
),
jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype),
x,
result_shape=jax.ShapeDtypeStruct(self.svmbir_output_shape, self.output_dtype),
)
return y.reshape(self.output_shape)

Expand Down Expand Up @@ -284,8 +282,8 @@ def _bproj(

def _bproj_hcb(self, y):
y = y.reshape(self.svmbir_output_shape)
# host callback wrapper for _bproj
x = jax.experimental.host_callback.call(
# callback wrapper for _bproj
x = jax.pure_callback(
lambda y: self._bproj(
y,
self.angles,
Expand All @@ -299,8 +297,8 @@ def _bproj_hcb(self, y):
delta_channel=self.delta_channel,
delta_pixel=self.delta_pixel,
),
jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype),
y,
result_shape=jax.ShapeDtypeStruct(self.svmbir_input_shape, self.input_dtype),
)
return x.reshape(self.input_shape)

Expand Down Expand Up @@ -389,7 +387,6 @@ def __init__(
raise TypeError(f"Parameter W must be None or a linop.Diagonal, got {type(W)}.")

def __call__(self, x: snp.Array) -> float:

if self.positivity and snp.sum(x < 0) > 0:
return snp.inf
else:
Expand Down

0 comments on commit 9e7de57

Please sign in to comment.