Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch ML examples to integrated X-ray transform #562

Merged
merged 70 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
9a48062
Bump maximum jaxlib/jax version and resolve some errors and warnings
bwohlberg Sep 19, 2024
7e82aa1
Handle get_backend change in jax 0.4.33
bwohlberg Sep 24, 2024
97eae59
Update checkpoint interface
crstngc Sep 24, 2024
27e6aa8
Remove unserializable partial functool from checkpoint
crstngc Sep 24, 2024
9fd4df5
Update orbax-checkpoint version constraint
bwohlberg Sep 25, 2024
485a651
Merge branch 'main' into brendt/jax-version-update
bwohlberg Sep 25, 2024
3a2b29a
Switch from astra to scico 2D projector
bwohlberg Sep 25, 2024
a0ec8db
Renamed MoDL and ODP CT examples
bwohlberg Sep 25, 2024
c043d01
Minor edits
bwohlberg Sep 25, 2024
50589ff
Bug fix
bwohlberg Sep 25, 2024
363e7c2
Choose parameters so that astra and scico projectors are equivalent
bwohlberg Sep 25, 2024
f3ce9d2
Better choice of det_count parameter
bwohlberg Sep 25, 2024
347d229
Fix test
bwohlberg Sep 25, 2024
8687bd0
Improve installation instructions
bwohlberg Sep 27, 2024
cfb7aaa
Improve warning in DnCNN function prox docs
bwohlberg Sep 27, 2024
5a29ef3
Typo fix
bwohlberg Sep 27, 2024
cf360bd
Bump copyright year
bwohlberg Sep 27, 2024
3dd11d7
Resolve typing error
bwohlberg Sep 27, 2024
dfdab90
Merge branch 'brendt/jax-version-update' into brendt/de-astra-fy
bwohlberg Sep 27, 2024
2a64c96
Merge branch 'main' into brendt/de-astra-fy
bwohlberg Oct 2, 2024
4d6c632
Add filtered back projection for 2D projector
bwohlberg Oct 5, 2024
1497ee6
Update change summary
bwohlberg Oct 5, 2024
97f3b05
Docstring fixes
bwohlberg Oct 5, 2024
ff1e235
Resolve errors in jitting method
bwohlberg Oct 6, 2024
9a1e217
Merge branch 'brendt/fbp' into brendt/de-astra-fy-extended
bwohlberg Oct 8, 2024
00e3cb6
Switch to scico projector for CT training data generation
bwohlberg Oct 8, 2024
30071df
Rename example
bwohlberg Oct 8, 2024
3772c50
Rename example
bwohlberg Oct 8, 2024
050fb87
Add conditional in case of prior ray.init
bwohlberg Oct 8, 2024
16ce866
Bug fix
bwohlberg Oct 8, 2024
b2c87ef
Trivial edit
bwohlberg Oct 8, 2024
68ff3b5
Bug fix
bwohlberg Oct 8, 2024
28b0a61
Improve consistency with similar examples
bwohlberg Oct 8, 2024
8d75256
Bug fix
bwohlberg Oct 8, 2024
41a09ae
Typo fix
bwohlberg Oct 8, 2024
ad3e59d
Remove astra import test
bwohlberg Oct 8, 2024
69a0525
Bug fix
bwohlberg Oct 8, 2024
5d37894
Some improvements
bwohlberg Oct 9, 2024
36e68c7
Merge branch 'brendt/fbp' into brendt/de-astra-fy-extended
bwohlberg Oct 9, 2024
e3aaf85
Update example index
bwohlberg Oct 11, 2024
eecad71
Update secondary indices
bwohlberg Oct 11, 2024
d6fc698
Update submodule
bwohlberg Oct 11, 2024
929aafb
Recent version of orbax-checkpoint not available via conda
bwohlberg Oct 11, 2024
b69e2eb
Remove mamba
bwohlberg Oct 11, 2024
6fbd0eb
Minor change
bwohlberg Oct 11, 2024
0fe5a14
Remove mamba
bwohlberg Oct 11, 2024
fc4fd48
Fix string syntax
bwohlberg Oct 11, 2024
b831694
Remove code that breaks notebook generation script
bwohlberg Oct 11, 2024
5ce5cd8
Bump jaxlib/jax max version
bwohlberg Oct 11, 2024
81d935a
Fix script
bwohlberg Oct 11, 2024
b0b86fc
Change default matplotlib backend selection
bwohlberg Oct 13, 2024
9ea0deb
Update submodule
bwohlberg Oct 13, 2024
f49430d
Clean up
bwohlberg Oct 14, 2024
5c9c974
Improve tests
bwohlberg Oct 15, 2024
e468e4b
Merge branch 'main' into brendt/fbp
bwohlberg Oct 15, 2024
250b00f
Merge branch 'main' into brendt/de-astra-fy-extended
bwohlberg Oct 15, 2024
e485390
Merge branch 'main' into brendt/fbp
bwohlberg Oct 15, 2024
b6bded8
Improve mask mechanism
bwohlberg Oct 15, 2024
62f0a50
Improve docs
bwohlberg Oct 15, 2024
46a6f6b
Merge branch 'brendt/fbp' into brendt/de-astra-fy-extended
bwohlberg Oct 15, 2024
d53007b
Minor docs edit
bwohlberg Oct 16, 2024
9781f22
Merge branch 'main' into brendt/de-astra-fy-extended
bwohlberg Oct 18, 2024
ce98fdf
Update submodule
Oct 18, 2024
7ddc41f
Update submodule
bwohlberg Oct 18, 2024
ab5588a
Update submodule
bwohlberg Oct 18, 2024
c4a7a49
Typing fixes
bwohlberg Oct 18, 2024
38ba3ae
Avoid mismatch with declared linop dtype
bwohlberg Oct 18, 2024
7c5ca89
Address review comment
bwohlberg Oct 21, 2024
a535919
Another search/replace error fix
bwohlberg Oct 22, 2024
d595214
Update submodule
bwohlberg Oct 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.34.
• Support ``flax`` versions 0.8.0 to 0.9.0.


Expand Down
17 changes: 8 additions & 9 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ Computed Tomography
examples/ct_svmbir_ppp_bm3d_admm_cg
examples/ct_svmbir_ppp_bm3d_admm_prox
examples/ct_fan_svmbir_ppp_bm3d_admm_prox
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_modl_train_foam2
examples/ct_odp_train_foam2
examples/ct_unet_train_foam2
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
Expand Down Expand Up @@ -96,7 +95,7 @@ Miscellaneous
examples/denoise_dncnn_universal
examples/diffusercam_tv_admm
examples/video_rpca_admm
examples/ct_astra_datagen_foam2
examples/ct_datagen_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/denoise_datagen_bsds
Expand Down Expand Up @@ -181,10 +180,10 @@ Machine Learning
.. toctree::
:maxdepth: 1

examples/ct_astra_datagen_foam2
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_datagen_foam2
examples/ct_modl_train_foam2
examples/ct_odp_train_foam2
examples/ct_unet_train_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/deconv_modl_train_foam1
Expand Down
4 changes: 2 additions & 2 deletions examples/jnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def py_file_to_string(src):

# Process remainder of source file
for line in srcfile:
if re.match("^input\(", line): # end processing when input statement encountered
if re.match(r"^input\(", line): # end processing when input statement encountered
break
line = re.sub('^r"""', '"""', line) # remove r from r"""
line = re.sub(":cite:\`([^`]+)\`", r'<cite data-cite="\1"/>', line) # fix cite format
line = re.sub(r":cite:\`([^`]+)\`", r'<cite data-cite="\1"/>', line) # fix cite format
lines.append(line)

# Backtrack through list of lines to remove trailing newlines
Expand Down
24 changes: 12 additions & 12 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ Computed Tomography
PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)
`ct_fan_svmbir_ppp_bm3d_admm_prox.py <ct_fan_svmbir_ppp_bm3d_admm_prox.py>`_
PPP (with BM3D) Fan-Beam CT Reconstruction
`ct_astra_modl_train_foam2.py <ct_astra_modl_train_foam2.py>`_
CT Training and Reconstructions with MoDL
`ct_astra_odp_train_foam2.py <ct_astra_odp_train_foam2.py>`_
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
`ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_
CT Training and Reconstruction with MoDL
`ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_
CT Training and Reconstruction with ODP
`ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison_2d.py <ct_projector_comparison_2d.py>`_
2D X-ray Transform Comparison
Expand Down Expand Up @@ -123,7 +123,7 @@ Miscellaneous
TV-Regularized 3D DiffuserCam Reconstruction
`video_rpca_admm.py <video_rpca_admm.py>`_
Video Decomposition via Robust PCA
`ct_astra_datagen_foam2.py <ct_astra_datagen_foam2.py>`_
`ct_datagen_foam2.py <ct_datagen_foam2.py>`_
CT Data Generation for NN Training
`deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_
Blurred Data Generation (Natural Images) for NN Training
Expand Down Expand Up @@ -239,13 +239,13 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^

`ct_astra_datagen_foam2.py <ct_astra_datagen_foam2.py>`_
`ct_datagen_foam2.py <ct_datagen_foam2.py>`_
CT Data Generation for NN Training
`ct_astra_modl_train_foam2.py <ct_astra_modl_train_foam2.py>`_
CT Training and Reconstructions with MoDL
`ct_astra_odp_train_foam2.py <ct_astra_odp_train_foam2.py>`_
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
`ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_
CT Training and Reconstruction with MoDL
`ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_
CT Training and Reconstruction with ODP
`ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_
Blurred Data Generation (Natural Images) for NN Training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
"""

# isort: off
import os
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from scico import plot
from scico.flax.examples import load_ct_data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
# with the package.

r"""
CT Training and Reconstructions with MoDL
=========================================
CT Training and Reconstruction with MoDL
========================================

This example demonstrates the training and application of a
model-based deep learning (MoDL) architecture described in
Expand Down Expand Up @@ -65,7 +65,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D
from scico.linop.xray import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
Expand All @@ -89,16 +89,17 @@


"""
Build CT projection operator.
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
det_spacing=1,
det_count=N,
angles=angles,
) # CT projection operator
A = (1.0 / N) * A # normalized
det_count=int(N * 1.05 / np.sqrt(2.0)),
dx=1.0 / np.sqrt(2),
)
A = (1.0 / N) * A # normalize projection operator


"""
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

det_count = N
det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
# with the package.

r"""
CT Training and Reconstructions with ODP
========================================
CT Training and Reconstruction with ODP
=======================================

This example demonstrates the training of the unrolled optimization with
deep priors (ODP) gradient descent architecture described in
Expand Down Expand Up @@ -72,7 +72,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D
from scico.linop.xray import XRayTransform2D


platform = get_backend().platform
Expand All @@ -92,21 +92,22 @@


"""
Build CT projection operator.
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
det_spacing=1,
det_count=N,
angles=angles,
) # CT projection operator
A = (1.0 / N) * A # normalized
det_count=int(N * 1.05 / np.sqrt(2.0)),
dx=1.0 / np.sqrt(2),
)
A = (1.0 / N) * A # normalize projection operator


"""
Build training and testing structures. Inputs are the sinograms and
outpus are the original generated foams. Keep training and testing
outputs are the original generated foams. Keep training and testing
partitions.
"""
numtr = 320
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/ct_projector_comparison_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
Create a ground truth image.
"""
N = 512

det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jnp.array(x_gt)

Expand All @@ -41,17 +38,18 @@
"""
num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
det_count = int(N * 1.02 / jnp.sqrt(2.0))

timer = Timer()

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayTransform2D((N, N), angles)
projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count)
timer.stop("scico_init")

timer.start("astra_init")
projectors["astra"] = astra.XRayTransform2D(
(N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0
(N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0
)
timer.stop("astra_init")

Expand Down
16 changes: 8 additions & 8 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ Computed Tomography
- ct_svmbir_ppp_bm3d_admm_cg.py
- ct_svmbir_ppp_bm3d_admm_prox.py
- ct_fan_svmbir_ppp_bm3d_admm_prox.py
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_modl_train_foam2.py
- ct_odp_train_foam2.py
- ct_unet_train_foam2.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py
Expand Down Expand Up @@ -73,7 +73,7 @@ Miscellaneous
- denoise_dncnn_universal.py
- diffusercam_tv_admm.py
- video_rpca_admm.py
- ct_astra_datagen_foam2.py
- ct_datagen_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- denoise_datagen_bsds.py
Expand Down Expand Up @@ -143,10 +143,10 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^

- ct_astra_datagen_foam2.py
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_datagen_foam2.py
- ct_modl_train_foam2.py
- ct_odp_train_foam2.py
- ct_unet_train_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- deconv_modl_train_foam1.py
Expand Down
1 change: 0 additions & 1 deletion misc/conda/install_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh
export PATH="$CONDAHOME/bin:$PATH"
hash -r
conda config --set always_yes yes
conda install mamba -n base -c conda-forge
conda update -q conda
conda info -a

Expand Down
11 changes: 4 additions & 7 deletions misc/conda/make_conda_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ EOF
)
# Requirements that cannot be installed via conda (i.e. have to use pip)
NOCONDA=$(cat <<-EOF
flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
EOF
)

Expand Down Expand Up @@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init`
conda activate $ENVNM # Q: why not `source activate`? A: not always in the path

# Add conda-forge channel
conda config --env --append channels conda-forge

# Install mamba
conda install mamba -n base -c conda-forge
conda config --append channels conda-forge

# Install required conda packages (and extra useful packages)
mamba install $CONDA_FLAGS $CONDAREQ ipython
conda install $CONDA_FLAGS $CONDAREQ ipython

# Utility ffmpeg is required by imageio for reading mp4 video files
# it can also be installed via the system package manager, .e.g.
# sudo apt install ffmpeg
if [ "$(which ffmpeg)" = '' ]; then
mamba install $CONDA_FLAGS ffmpeg
conda install $CONDA_FLAGS ffmpeg
fi

# Install jaxlib and jax
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.33
jax>=0.4.3,<=0.4.33
jaxlib>=0.4.3,<=0.4.34
jax>=0.4.3,<=0.4.34
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
Loading
Loading