diff --git a/CHANGES.rst b/CHANGES.rst
index e5aa718b..eb4c3b97 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -20,7 +20,7 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
-• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33.
+• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.34.
• Support ``flax`` versions 0.8.0 to 0.9.0.
diff --git a/data b/data
index c1233896..1ceadbbb 160000
--- a/data
+++ b/data
@@ -1 +1 @@
-Subproject commit c12338966b1b9f92554066743b1a8b664c7b7e24
+Subproject commit 1ceadbbbe6bef9f364fd76dbda44ff6e185a7d10
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index 58ba847f..b97a52bd 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -34,12 +34,11 @@ Computed Tomography
examples/ct_svmbir_ppp_bm3d_admm_cg
examples/ct_svmbir_ppp_bm3d_admm_prox
examples/ct_fan_svmbir_ppp_bm3d_admm_prox
- examples/ct_astra_modl_train_foam2
- examples/ct_astra_odp_train_foam2
- examples/ct_astra_unet_train_foam2
+ examples/ct_modl_train_foam2
+ examples/ct_odp_train_foam2
+ examples/ct_unet_train_foam2
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
- examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm
Deconvolution
@@ -96,7 +95,7 @@ Miscellaneous
examples/denoise_dncnn_universal
examples/diffusercam_tv_admm
examples/video_rpca_admm
- examples/ct_astra_datagen_foam2
+ examples/ct_datagen_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/denoise_datagen_bsds
@@ -181,10 +180,10 @@ Machine Learning
.. toctree::
:maxdepth: 1
- examples/ct_astra_datagen_foam2
- examples/ct_astra_modl_train_foam2
- examples/ct_astra_odp_train_foam2
- examples/ct_astra_unet_train_foam2
+ examples/ct_datagen_foam2
+ examples/ct_modl_train_foam2
+ examples/ct_odp_train_foam2
+ examples/ct_unet_train_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/deconv_modl_train_foam1
diff --git a/examples/jnb.py b/examples/jnb.py
index 9135add1..f51c3f9a 100644
--- a/examples/jnb.py
+++ b/examples/jnb.py
@@ -60,10 +60,10 @@ def py_file_to_string(src):
# Process remainder of source file
for line in srcfile:
- if re.match("^input\(", line): # end processing when input statement encountered
+ if re.match(r"^input\(", line): # end processing when input statement encountered
break
line = re.sub('^r"""', '"""', line) # remove r from r"""
- line = re.sub(":cite:\`([^`]+)\`", r'', line) # fix cite format
+ line = re.sub(r":cite:\`([^`]+)\`", r'', line) # fix cite format
lines.append(line)
# Backtrack through list of lines to remove trailing newlines
diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst
index 446186a7..95e11f7c 100644
--- a/examples/scripts/README.rst
+++ b/examples/scripts/README.rst
@@ -33,11 +33,11 @@ Computed Tomography
PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)
`ct_fan_svmbir_ppp_bm3d_admm_prox.py `_
PPP (with BM3D) Fan-Beam CT Reconstruction
- `ct_astra_modl_train_foam2.py `_
- CT Training and Reconstructions with MoDL
- `ct_astra_odp_train_foam2.py `_
- CT Training and Reconstructions with ODP
- `ct_astra_unet_train_foam2.py `_
+ `ct_modl_train_foam2.py `_
+ CT Training and Reconstruction with MoDL
+ `ct_odp_train_foam2.py `_
+ CT Training and Reconstruction with ODP
+ `ct_unet_train_foam2.py `_
CT Training and Reconstructions with UNet
`ct_projector_comparison_2d.py `_
2D X-ray Transform Comparison
@@ -123,7 +123,7 @@ Miscellaneous
TV-Regularized 3D DiffuserCam Reconstruction
`video_rpca_admm.py `_
Video Decomposition via Robust PCA
- `ct_astra_datagen_foam2.py `_
+ `ct_datagen_foam2.py `_
CT Data Generation for NN Training
`deconv_datagen_bsds.py `_
Blurred Data Generation (Natural Images) for NN Training
@@ -239,13 +239,13 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^
- `ct_astra_datagen_foam2.py `_
+ `ct_datagen_foam2.py `_
CT Data Generation for NN Training
- `ct_astra_modl_train_foam2.py `_
- CT Training and Reconstructions with MoDL
- `ct_astra_odp_train_foam2.py `_
- CT Training and Reconstructions with ODP
- `ct_astra_unet_train_foam2.py `_
+ `ct_modl_train_foam2.py `_
+ CT Training and Reconstruction with MoDL
+ `ct_odp_train_foam2.py `_
+ CT Training and Reconstruction with ODP
+ `ct_unet_train_foam2.py `_
CT Training and Reconstructions with UNet
`deconv_datagen_bsds.py `_
Blurred Data Generation (Natural Images) for NN Training
diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_datagen_foam2.py
similarity index 93%
rename from examples/scripts/ct_astra_datagen_foam2.py
rename to examples/scripts/ct_datagen_foam2.py
index 4e6fb97c..2fc9be59 100644
--- a/examples/scripts/ct_astra_datagen_foam2.py
+++ b/examples/scripts/ct_datagen_foam2.py
@@ -14,6 +14,7 @@
"""
# isort: off
+import os
import numpy as np
import logging
@@ -21,6 +22,9 @@
ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087
+# Set an arbitrary processor count (only applies if GPU is not available).
+os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+
from scico import plot
from scico.flax.examples import load_ct_data
diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_modl_train_foam2.py
similarity index 96%
rename from examples/scripts/ct_astra_modl_train_foam2.py
rename to examples/scripts/ct_modl_train_foam2.py
index a06d7b81..19a3d810 100644
--- a/examples/scripts/ct_astra_modl_train_foam2.py
+++ b/examples/scripts/ct_modl_train_foam2.py
@@ -5,8 +5,8 @@
# with the package.
r"""
-CT Training and Reconstructions with MoDL
-=========================================
+CT Training and Reconstruction with MoDL
+========================================
This example demonstrates the training and application of a
model-based deep learning (MoDL) architecture described in
@@ -65,7 +65,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform2D
+from scico.linop.xray import XRayTransform2D
"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -89,16 +89,17 @@
"""
-Build CT projection operator.
+Build CT projection operator. Parameters are chosen so that the operator
+is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
- det_spacing=1,
- det_count=N,
angles=angles,
-) # CT projection operator
-A = (1.0 / N) * A # normalized
+ det_count=int(N * 1.05 / np.sqrt(2.0)),
+ dx=1.0 / np.sqrt(2),
+)
+A = (1.0 / N) * A # normalize projection operator
"""
diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py
index 95711679..f2f13fd8 100644
--- a/examples/scripts/ct_multi_tv_admm.py
+++ b/examples/scripts/ct_multi_tv_admm.py
@@ -38,7 +38,7 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))
-det_count = N
+det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)
diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_odp_train_foam2.py
similarity index 94%
rename from examples/scripts/ct_astra_odp_train_foam2.py
rename to examples/scripts/ct_odp_train_foam2.py
index 03753c1b..e5cd58ae 100644
--- a/examples/scripts/ct_astra_odp_train_foam2.py
+++ b/examples/scripts/ct_odp_train_foam2.py
@@ -5,8 +5,8 @@
# with the package.
r"""
-CT Training and Reconstructions with ODP
-========================================
+CT Training and Reconstruction with ODP
+=======================================
This example demonstrates the training of the unrolled optimization with
deep priors (ODP) gradient descent architecture described in
@@ -72,7 +72,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.xray.astra import XRayTransform2D
+from scico.linop.xray import XRayTransform2D
platform = get_backend().platform
@@ -92,21 +92,22 @@
"""
-Build CT projection operator.
+Build CT projection operator. Parameters are chosen so that the operator
+is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
- det_spacing=1,
- det_count=N,
angles=angles,
-) # CT projection operator
-A = (1.0 / N) * A # normalized
+ det_count=int(N * 1.05 / np.sqrt(2.0)),
+ dx=1.0 / np.sqrt(2),
+)
+A = (1.0 / N) * A # normalize projection operator
"""
Build training and testing structures. Inputs are the sinograms and
-outpus are the original generated foams. Keep training and testing
+outputs are the original generated foams. Keep training and testing
partitions.
"""
numtr = 320
diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py
index 2e5d02d3..0a47c7b0 100644
--- a/examples/scripts/ct_projector_comparison_2d.py
+++ b/examples/scripts/ct_projector_comparison_2d.py
@@ -29,9 +29,6 @@
Create a ground truth image.
"""
N = 512
-
-det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))
-
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jnp.array(x_gt)
@@ -41,17 +38,18 @@
"""
num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
+det_count = int(N * 1.02 / jnp.sqrt(2.0))
timer = Timer()
projectors = {}
timer.start("scico_init")
-projectors["scico"] = XRayTransform2D((N, N), angles)
+projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count)
timer.stop("scico_init")
timer.start("astra_init")
projectors["astra"] = astra.XRayTransform2D(
- (N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0
+ (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0
)
timer.stop("astra_init")
diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_unet_train_foam2.py
similarity index 100%
rename from examples/scripts/ct_astra_unet_train_foam2.py
rename to examples/scripts/ct_unet_train_foam2.py
diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst
index 4f05ba2f..e36e9fbd 100644
--- a/examples/scripts/index.rst
+++ b/examples/scripts/index.rst
@@ -21,9 +21,9 @@ Computed Tomography
- ct_svmbir_ppp_bm3d_admm_cg.py
- ct_svmbir_ppp_bm3d_admm_prox.py
- ct_fan_svmbir_ppp_bm3d_admm_prox.py
- - ct_astra_modl_train_foam2.py
- - ct_astra_odp_train_foam2.py
- - ct_astra_unet_train_foam2.py
+ - ct_modl_train_foam2.py
+ - ct_odp_train_foam2.py
+ - ct_unet_train_foam2.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py
@@ -73,7 +73,7 @@ Miscellaneous
- denoise_dncnn_universal.py
- diffusercam_tv_admm.py
- video_rpca_admm.py
- - ct_astra_datagen_foam2.py
+ - ct_datagen_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- denoise_datagen_bsds.py
@@ -143,10 +143,10 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^
- - ct_astra_datagen_foam2.py
- - ct_astra_modl_train_foam2.py
- - ct_astra_odp_train_foam2.py
- - ct_astra_unet_train_foam2.py
+ - ct_datagen_foam2.py
+ - ct_modl_train_foam2.py
+ - ct_odp_train_foam2.py
+ - ct_unet_train_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- deconv_modl_train_foam1.py
diff --git a/misc/conda/install_conda.sh b/misc/conda/install_conda.sh
index 73defcaa..744d1836 100755
--- a/misc/conda/install_conda.sh
+++ b/misc/conda/install_conda.sh
@@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh
export PATH="$CONDAHOME/bin:$PATH"
hash -r
conda config --set always_yes yes
-conda install mamba -n base -c conda-forge
conda update -q conda
conda info -a
diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh
index b5aa4411..cab2b7e2 100755
--- a/misc/conda/make_conda_env.sh
+++ b/misc/conda/make_conda_env.sh
@@ -50,7 +50,7 @@ EOF
)
# Requirements that cannot be installed via conda (i.e. have to use pip)
NOCONDA=$(cat <<-EOF
-flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
+flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
EOF
)
@@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init`
conda activate $ENVNM # Q: why not `source activate`? A: not always in the path
# Add conda-forge channel
-conda config --env --append channels conda-forge
-
-# Install mamba
-conda install mamba -n base -c conda-forge
+conda config --append channels conda-forge
# Install required conda packages (and extra useful packages)
-mamba install $CONDA_FLAGS $CONDAREQ ipython
+conda install $CONDA_FLAGS $CONDAREQ ipython
# Utility ffmpeg is required by imageio for reading mp4 video files
# it can also be installed via the system package manager, .e.g.
# sudo apt install ffmpeg
if [ "$(which ffmpeg)" = '' ]; then
- mamba install $CONDA_FLAGS ffmpeg
+ conda install $CONDA_FLAGS ffmpeg
fi
# Install jaxlib and jax
diff --git a/requirements.txt b/requirements.txt
index 1b0e5359..feab852c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
-jaxlib>=0.4.3,<=0.4.33
-jax>=0.4.3,<=0.4.33
+jaxlib>=0.4.3,<=0.4.34
+jax>=0.4.3,<=0.4.34
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py
index e7a99030..e89f10cf 100644
--- a/scico/flax/examples/data_generation.py
+++ b/scico/flax/examples/data_generation.py
@@ -51,16 +51,9 @@ class UnitCircle:
from jax.lib.xla_bridge import get_backend
from scico.linop import CircularConvolve
+from scico.linop.xray import XRayTransform2D
from scico.numpy import Array
-try:
- import astra # noqa: F401
-except ImportError:
- have_astra = False
-else:
- have_astra = True
- from scico.linop.xray.astra import XRayTransform2D
-
class Foam2(UnitCircle):
"""Foam-like material with two attenuations.
@@ -218,10 +211,8 @@ def generate_ct_data(
- **sino** : (:class:`jax.Array`): Corresponding sinograms.
- **fbp** : (:class:`jax.Array`) Corresponding filtered back projections.
"""
- if not (have_ray and have_xdesign and have_astra):
- raise RuntimeError(
- "Packages ray, xdesign, and astra are required for use of this function."
- )
+ if not (have_ray and have_xdesign):
+ raise RuntimeError("Packages ray and xdesign are required for use of this function.")
# Generate input data.
start_time = time()
@@ -234,17 +225,17 @@ def generate_ct_data(
# Configure a CT projection operator to generate synthetic measurements.
angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles
- gt_sh = (size, size)
- detector_spacing = 1.0
- A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator
-
+ gt_shape = (size, size)
+ dx = 1.0 / np.sqrt(2)
+ det_count = int(size * 1.05 / np.sqrt(2.0))
+ A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count)
# Compute sinograms in parallel.
start_time = time()
if nproc > 1:
# shard array
imgshd = img.reshape((nproc, -1, size, size, 1))
sinoshd = batched_f(A, imgshd)
- sino = sinoshd.reshape((-1, nproj, size, 1))
+ sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1))
else:
sino = vector_f(A, img)
@@ -261,8 +252,8 @@ def generate_ct_data(
# Normalize sinogram.
sino = sino / size
- # Shift FBP to [0,1] range.
- fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())
+ # Clip FBP to [0,1] range.
+ fbp = np.clip(fbp, 0, 1)
if verbose: # pragma: no cover
platform = get_backend().platform
diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py
index 256791a7..1cbef6fc 100644
--- a/scico/functional/_functional.py
+++ b/scico/functional/_functional.py
@@ -117,7 +117,7 @@ def conj_prox(
return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs)
def grad(self, x: Union[Array, BlockArray]):
- r"""Evaluates the gradient of this functional at :math:`\mb{x}`.
+ r"""Evaluate the gradient of this functional at :math:`\mb{x}`.
Args:
x: Point at which to evaluate gradient.
diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py
index 83bd8462..9d459db5 100644
--- a/scico/linop/xray/_xray.py
+++ b/scico/linop/xray/_xray.py
@@ -9,7 +9,7 @@
from functools import partial
-from typing import Optional
+from typing import Optional, Tuple
from warnings import warn
import numpy as np
@@ -273,14 +273,14 @@ def _back_project(
y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0
)
- return HTy
+ return HTy.astype(jnp.float32)
@staticmethod
@partial(jax.jit, static_argnames=["nx"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(
- x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float
- ) -> snp.Array:
+ x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float
+ ) -> Tuple[snp.Array, snp.Array]:
"""
Args:
@@ -288,12 +288,12 @@ def _calc_weights(
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
nx: Input image shape.
- angle: (num_angles,) array of angles in radians. Pixels are
+ angles: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
(This argument is `vmap`ed.)
y0: Location of the edge of the first detector bin.
"""
- u = [jnp.cos(angle), jnp.sin(angle)]
+ u = [jnp.cos(angles), jnp.sin(angles)]
Px0 = x0[0] * u[0] + x0[1] * u[1] - y0
Pdx = [dx[0] * u[0], dx[1] * u[1]]
Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]]))
diff --git a/scico/plot.py b/scico/plot.py
index 6c0375be..72ae4962 100644
--- a/scico/plot.py
+++ b/scico/plot.py
@@ -13,6 +13,7 @@
# This module is copied from https://github.com/bwohlberg/sporco
+import os
import sys
import numpy as np
@@ -820,7 +821,8 @@ def config_notebook_plotting():
Configure plotting functions for inline plotting within a Jupyter
Notebook shell. This function has no effect when not within a
notebook shell, and may therefore be used within a normal python
- script.
+ script. If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set,
+ the matplotlib backend is explicitly set to the specified value.
"""
# Check whether running within a notebook shell and have
@@ -828,8 +830,9 @@ def config_notebook_plotting():
module = sys.modules[__name__]
if _in_notebook() and module.plot.__name__ == "plot":
- # Set inline backend (i.e. %matplotlib inline) if in a notebook shell
- set_notebook_plot_backend()
+ # Set backend if specified by environment variable
+ if "MATPLOTLIB_IPYNB_BACKEND" in os.environ:
+ set_notebook_plot_backend(os.environ["MATPLOTLIB_IPYNB_BACKEND"])
# Replace plot function with a wrapper function that discards
# its return value (within a notebook with inline plotting, plots
diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py
index 72c084dd..ce265d05 100644
--- a/scico/test/flax/test_examples_flax.py
+++ b/scico/test/flax/test_examples_flax.py
@@ -12,7 +12,6 @@
generate_ct_data,
generate_foam1_images,
generate_foam2_images,
- have_astra,
have_ray,
have_xdesign,
)
@@ -75,8 +74,8 @@ def random_data_gen(seed, N, ndata):
@pytest.mark.skipif(
- not have_astra or not have_ray or not have_xdesign,
- reason="astra, ray, or xdesign package not installed",
+ not have_ray or not have_xdesign,
+ reason="ray or xdesign package not installed",
)
def test_ct_data_generation():
N = 32
@@ -90,7 +89,7 @@ def random_img_gen(seed, size, ndata):
img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen)
assert img.shape == (nimg, N, N, 1)
- assert sino.shape == (nimg, nproj, N, 1)
+ assert sino.shape == (nimg, nproj, sino.shape[2], 1)
assert fbp.shape == (nimg, N, N, 1)
diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py
index b63a88a8..d49df64b 100644
--- a/scico/test/flax/test_inv.py
+++ b/scico/test/flax/test_inv.py
@@ -6,18 +6,12 @@
import jax.numpy as jnp
from jax import lax
-import pytest
-
from scico import flax as sflax
from scico import random
from scico.flax.examples import PaddedCircularConvolve, build_blur_kernel
-from scico.flax.examples.data_generation import have_astra
from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal
from scico.linop import CircularConvolve, Identity
-
-if have_astra:
- from scico.linop.xray.astra import XRayTransform2D
-
+from scico.linop.xray import XRayTransform2D
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
@@ -153,7 +147,6 @@ def test_train_odpdcnv_default(self):
np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval)
-@pytest.mark.skipif(not have_astra, reason="astra package not installed")
class TestCT:
def setup_method(self, method):
self.N = 32 # signal size
@@ -162,11 +155,10 @@ def setup_method(self, method):
xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321)
self.nproj = 60 # number of projections
- angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles
+ angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32)
self.opCT = XRayTransform2D(
input_shape=(self.N, self.N),
det_count=self.N,
- det_spacing=1.0,
angles=angles,
) # Radon transform operator
a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze()))