diff --git a/.github/workflows/pytest_ubuntu.yml b/.github/workflows/pytest_ubuntu.yml index ffad0c04f..babfb1d9f 100644 --- a/.github/workflows/pytest_ubuntu.yml +++ b/.github/workflows/pytest_ubuntu.yml @@ -89,7 +89,7 @@ jobs: - name: Run doc tests if: matrix.group == 1 run: | - pytest --ignore-glob="*test_*.py" --doctest-modules scico + pytest --ignore-glob="*test_*.py" --ignore=scico/linop/xray --doctest-modules scico pytest --doctest-glob="*.rst" docs coverage: diff --git a/CHANGES.rst b/CHANGES.rst index 798ba887f..a2e246978 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,8 @@ Version 0.0.6 (unreleased) ``linop.Parallel3dProjector``. • New functional ``functional.IsotropicTVNorm`` and faster implementation of ``functional.AnisotropicTVNorm``. +• New linear operators ``linop.ProjectedGradient``, ``linop.PolarGradient``, + ``linop.CylindricalGradient``, and ``linop.SphericalGradient``. • Rename ``scico.numpy.util.parse_axes`` to ``scico.numpy.util.normalize_axes``. • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to diff --git a/docs/docs_requirements.txt b/docs/docs_requirements.txt index 1e55847bf..fbd4cafc4 100644 --- a/docs/docs_requirements.txt +++ b/docs/docs_requirements.txt @@ -3,13 +3,13 @@ sphinx>=5.0.0 sphinxcontrib-napoleon sphinxcontrib-bibtex sphinx-autodoc-typehints -furo +furo>=2024.5.6 jinja2<3.1.0 # temporary fix for jinja2/nbconvert bug traitlets!=5.2.2 # temporary fix for ipython/traitlets#741 nbsphinx ipython ipython_genutils py2jn -pygraphviz>=1.7 +pygraphviz>=1.9 pandoc docutils>=0.18 diff --git a/docs/source/conf/15-theme.py b/docs/source/conf/15-theme.py index 35ad6d669..8d20c3199 100644 --- a/docs/source/conf/15-theme.py +++ b/docs/source/conf/15-theme.py @@ -6,7 +6,7 @@ html_theme = "furo" html_theme_options = { - "top_of_page_button": None, + "top_of_page_buttons": [], # "sidebar_hide_name": True, } diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 43dad7e3b..58ba847f0 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -86,6 +86,7 @@ Miscellaneous examples/demosaic_ppp_bm3d_admm examples/superres_ppp_dncnn_admm examples/denoise_l1tv_admm + examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi @@ -138,6 +139,7 @@ Total Variation examples/ct_astra_3d_tv_admm examples/ct_astra_3d_tv_padmm examples/ct_astra_weighted_tv_admm + examples/ct_multi_tv_admm examples/ct_svmbir_tv_multi examples/deconv_circ_tv_admm examples/deconv_tv_admm @@ -146,6 +148,7 @@ Total Variation examples/deconv_microscopy_tv_admm examples/deconv_microscopy_allchn_tv_admm examples/denoise_l1tv_admm + examples/denoise_ptv_pdhg examples/denoise_tv_admm examples/denoise_tv_apgm examples/denoise_tv_multi @@ -209,6 +212,7 @@ ADMM examples/ct_tv_admm examples/ct_astra_3d_tv_admm examples/ct_astra_weighted_tv_admm + examples/ct_multi_tv_admm examples/ct_svmbir_tv_multi examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox @@ -272,6 +276,7 @@ PDHG :maxdepth: 1 examples/ct_svmbir_tv_multi + examples/denoise_ptv_pdhg examples/denoise_tv_multi examples/denoise_cplx_tv_pdhg diff --git a/docs/source/references.bib b/docs/source/references.bib index a150a9b07..257f24287 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -362,6 +362,15 @@ @Book {goodman-2005-fourier edition = 3 } +@Misc {hossein-2024-total, + title = {Total Variation Regularization for Tomographic + Reconstruction of Cylindrically Symmetric Objects}, + author = {Maliha Hossain and Charles A. Bouman and Brendt + Wohlberg}, + year = 2024, + eprint = {2406.17928} +} + @Article {huber-1964-robust, doi = {10.1214/aoms/1177703732}, year = 1964, @@ -776,6 +785,7 @@ @Article {zhang-2021-plug pages = {6360--6376} } + @Article {zhou-2006-adaptive, author = {Bin Zhou and Li Gao and Yu-Hong Dai}, title = {Gradient Methods with Adaptive Step-Sizes}, diff --git a/docs/source/style.rst b/docs/source/style.rst index f26bcb85a..80fe8fd27 100644 --- a/docs/source/style.rst +++ b/docs/source/style.rst @@ -439,7 +439,7 @@ classes, or method definitions. Comment explaining example 1. - >>> np.add(1, 2) + >>> int(np.add(1, 2)) 3 Comment explaining a new example. diff --git a/examples/jnb.py b/examples/jnb.py index 5b8ead6ee..9135add1d 100644 --- a/examples/jnb.py +++ b/examples/jnb.py @@ -33,11 +33,12 @@ def py_file_to_string(src): if import_seen: # Once an import statement has been seen, break on encountering a line that # is neither an import statement nor a newline, nor a component of an import - # statement extended over multiple lines, nor an os.environ statement, nor - # components of a try/except construction (note that handling of these final - # two cases is probably not very robust). + # statement extended over multiple lines, nor an os.environ statement, nor a + # ray.init statement, nor components of a try/except construction (note that + # handling of these final two cases is probably not very robust). if not re.match( - r"(^import|^from|^\n$|^\W+[^\W]|^\)$|^os.environ|^try:$|^except)", line + r"(^import|^from|^\n$|^\W+[^\W]|^\)$|^os.environ|^ray.init|^try:$|^except)", + line, ): lines.append(line) break diff --git a/examples/notebooks_requirements.txt b/examples/notebooks_requirements.txt index bcb9e03c9..644a4db21 100644 --- a/examples/notebooks_requirements.txt +++ b/examples/notebooks_requirements.txt @@ -1,4 +1,6 @@ -r examples-requirements.txt +ipykernel +ipywidgets nbformat nbconvert nb_conda_kernels diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index b4fe1a740..446186a76 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -43,8 +43,6 @@ Computed Tomography 2D X-ray Transform Comparison `ct_projector_comparison_3d.py `_ 3D X-ray Transform Comparison - `ct_multi_cs_tv_admm.py `_ - TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram) `ct_multi_tv_admm.py `_ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) @@ -105,6 +103,8 @@ Miscellaneous PPP (with DnCNN) Image Superresolution `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_apgm.py `_ @@ -178,6 +178,8 @@ Total Variation 3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver) `ct_astra_weighted_tv_admm.py `_ TV-Regularized Low-Dose CT Reconstruction + `ct_multi_tv_admm.py `_ + TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `deconv_circ_tv_admm.py `_ @@ -194,6 +196,8 @@ Total Variation Deconvolution Microscopy (All Channels) `denoise_l1tv_admm.py `_ ℓ1 Total Variation Denoising + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_admm.py `_ Total Variation Denoising (ADMM) `denoise_tv_apgm.py `_ @@ -277,6 +281,8 @@ ADMM 3D TV-Regularized Sparse-View CT Reconstruction (ADMM Solver) `ct_astra_weighted_tv_admm.py `_ TV-Regularized Low-Dose CT Reconstruction + `ct_multi_tv_admm.py `_ + TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors) `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) `ct_svmbir_ppp_bm3d_admm_cg.py `_ @@ -359,6 +365,8 @@ PDHG `ct_svmbir_tv_multi.py `_ TV-Regularized CT Reconstruction (Multiple Algorithms) + `denoise_ptv_pdhg.py `_ + Polar Total Variation Denoising (PDHG) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising `denoise_cplx_tv_pdhg.py `_ diff --git a/examples/scripts/ct_multi_cs_tv_admm.py b/examples/scripts/ct_multi_cs_tv_admm.py deleted file mode 100644 index cb8ddc6b6..000000000 --- a/examples/scripts/ct_multi_cs_tv_admm.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# This file is part of the SCICO package. Details of the copyright -# and user license can be found in the 'LICENSE.txt' file distributed -# with the package. - -r""" -TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram) -=================================================================================== - -This example demonstrates solution of a sparse-view CT reconstruction -problem with isotropic total variation (TV) regularization - - $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} - \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$ - -where $A$ is the X-ray transform (the CT forward projection operator), -$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and -$\mathbf{x}$ is the desired image. The solution is computed and compared -for all three 2D CT projectors available in scico, using a sinogram -computed with the svmbir projector. -""" - -import numpy as np - -from xdesign import Foam, discrete_phantom - -import scico.numpy as snp -from scico import functional, linop, loss, metric, plot -from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir -from scico.optimize.admm import ADMM, LinearSubproblemSolver -from scico.util import device_info - -""" -Create a ground truth image. -""" -N = 512 # phantom size -np.random.seed(1234) -x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) - -det_count = N -det_spacing = np.sqrt(2) - - -""" -Define CT geometry and construct array of (approximately) equivalent projectors. -""" -n_projection = 45 # number of projections -angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles -projectors = { - "astra": astra.XRayTransform2D( - x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0 - ), # astra - "svmbir": svmbir.XRayTransform( - x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing - ), # svmbir - "scico": XRayTransform( - Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing) - ), # scico -} - - -""" -Compute common sinogram using svmbir projector. -""" -A = projectors["astra"] -noise = np.random.normal(size=(n_projection, det_count)).astype(np.float32) -y = A @ x_gt + 2.0 * noise - - -""" -Construct initial solution for regularized problem. -""" -x0 = A.fbp(y) - - -""" -Solve the same problem using the different projectors. -""" -print(f"Solving on {device_info()}") -x_rec, hist = {}, {} -for p in projectors.keys(): - print(f"\nSolving with {p} projector") - - # Set up ADMM solver object. - λ = 2e1 # L1 norm regularization parameter - ρ = 1e3 # ADMM penalty parameter - maxiter = 100 # number of ADMM iterations - cg_tol = 1e-4 # CG relative tolerance - cg_maxiter = 50 # maximum CG iterations per ADMM iteration - - # The append=0 option makes the results of horizontal and vertical - # finite differences the same shape, which is required for the L21Norm, - # which is used so that g(Cx) corresponds to isotropic TV. - C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) - g = λ * functional.L21Norm() - A = projectors[p] - f = loss.SquaredL2Loss(y=y, A=A) - - # Set up the solver. - solver = ADMM( - f=f, - g_list=[g], - C_list=[C], - rho_list=[ρ], - x0=x0, - maxiter=maxiter, - subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}), - itstat_options={"display": True, "period": 5}, - ) - - # Run the solver. - solver.solve() - hist[p] = solver.itstat_object.history(transpose=True) - x_rec[p] = solver.x - - if p == "scico": - x_rec[p] = x_rec[p] * det_spacing # to match ASTRA's scaling - - -""" -Compare reconstruction results. -""" -print("Reconstruction SNR:") -for p in projectors.keys(): - print(f" {(p + ':'):7s} {metric.snr(x_gt, x_rec[p]):5.2f} dB") - - -""" -Display sinogram. -""" -fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3)) -plot.imview(y, title="sinogram", fig=fig, ax=ax) -fig.show() - - -""" -Plot convergence statistics. -""" -fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5)) -plot.plot( - np.vstack([hist[p].Objective for p in projectors.keys()]).T, - title="Objective function", - xlbl="Iteration", - ylbl="Functional value", - lgnd=projectors.keys(), - fig=fig, - ax=ax[0], -) -plot.plot( - np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T, - ptyp="semilogy", - title="Primal Residual", - xlbl="Iteration", - fig=fig, - ax=ax[1], -) -plot.plot( - np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T, - ptyp="semilogy", - title="Dual Residual", - xlbl="Iteration", - fig=fig, - ax=ax[2], -) -fig.show() - - -""" -Show the recovered images. -""" -fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5)) -plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0]) -for n, p in enumerate(projectors.keys()): - plot.imview( - x_rec[p], - title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])), - fig=fig, - ax=ax[n + 1], - ) -for ax in ax: - ax.get_images()[0].set_clim(-0.1, 1.1) -fig.show() - - -input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 084ad99ac..87fc86865 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -17,7 +17,8 @@ where $A$ is the X-ray transform (the CT forward projection operator), $\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and $\mathbf{x}$ is the desired image. The solution is computed and compared -for all three 2D CT projectors available in scico. +for all three 2D CT projectors available in scico, using a sinogram +computed with the astra projector. """ import numpy as np @@ -37,6 +38,9 @@ np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) +det_count = N +det_spacing = np.sqrt(2) + """ Define CT geometry and construct array of (approximately) equivalent projectors. @@ -44,37 +48,54 @@ n_projection = 45 # number of projections angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles projectors = { - "astra": astra.XRayTransform2D(x_gt.shape, N, 1.0, angles - np.pi / 2.0), # astra - "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir - "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico + "astra": astra.XRayTransform2D( + x_gt.shape, det_count, det_spacing, angles - np.pi / 2.0 + ), # astra + "svmbir": svmbir.XRayTransform( + x_gt.shape, 2 * np.pi - angles, det_count, delta_pixel=1.0, delta_channel=det_spacing + ), # svmbir + "scico": XRayTransform( + Parallel2dProjector((N, N), angles, det_count=det_count, dx=1 / det_spacing) + ), # scico } +""" +Compute common sinogram using astra projector. +""" +A = projectors["astra"] +noise = np.random.normal(size=(n_projection, det_count)).astype(np.float32) +y = A @ x_gt + 2.0 * noise + + +""" +Construct initial solution for regularized problem. +""" +x0 = A.fbp(y) + + """ Solve the same problem using the different projectors. """ print(f"Solving on {device_info()}") -y, x_rec, hist = {}, {}, {} -noise = np.random.normal(size=(n_projection, N)).astype(np.float32) -for p in ("astra", "svmbir", "scico"): +x_rec, hist = {}, {} +for p in projectors.keys(): print(f"\nSolving with {p} projector") - A = projectors[p] - y[p] = A @ x_gt + 2.0 * noise # sinogram # Set up ADMM solver object. - λ = 2e0 # L1 norm regularization parameter - ρ = 5e0 # ADMM penalty parameter - maxiter = 25 # number of ADMM iterations + λ = 2e1 # L1 norm regularization parameter + ρ = 1e3 # ADMM penalty parameter + maxiter = 100 # number of ADMM iterations cg_tol = 1e-4 # CG relative tolerance - cg_maxiter = 25 # maximum CG iterations per ADMM iteration + cg_maxiter = 50 # maximum CG iterations per ADMM iteration # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, # which is used so that g(Cx) corresponds to isotropic TV. C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) g = λ * functional.L21Norm() - f = loss.SquaredL2Loss(y=y[p], A=A) - x0 = snp.clip(A.T(y[p]), 0, 1.0) + A = projectors[p] + f = loss.SquaredL2Loss(y=y, A=A) # Set up the solver. solver = ADMM( @@ -91,15 +112,25 @@ # Run the solver. solver.solve() hist[p] = solver.itstat_object.history(transpose=True) - x_rec[p] = snp.clip(solver.x, 0, 1.0) + x_rec[p] = solver.x + + if p == "scico": + x_rec[p] = x_rec[p] * det_spacing # to match ASTRA's scaling + + +""" +Compare reconstruction results. +""" +print("Reconstruction SNR:") +for p in projectors.keys(): + print(f" {(p + ':'):7s} {metric.snr(x_gt, x_rec[p]):5.2f} dB") """ -Compare sinograms. +Display sinogram. """ -fig, ax = plot.subplots(nrows=3, ncols=1, figsize=(15, 10)) -for idx, name in enumerate(projectors.keys()): - plot.imview(y[name], title=f"{name} sinogram", cbar=None, fig=fig, ax=ax[idx]) +fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3)) +plot.imview(y, title="sinogram", fig=fig, ax=ax) fig.show() @@ -147,6 +178,8 @@ fig=fig, ax=ax[n + 1], ) +for ax in ax: + ax.get_images()[0].set_clim(-0.1, 1.1) fig.show() diff --git a/examples/scripts/deconv_circ_tv_admm.py b/examples/scripts/deconv_circ_tv_admm.py index b2ba83202..90e898dc7 100644 --- a/examples/scripts/deconv_circ_tv_admm.py +++ b/examples/scripts/deconv_circ_tv_admm.py @@ -54,7 +54,7 @@ """ Set up an ADMM solver object. """ -λ = 2e-2 # L21 norm regularization parameter +λ = 2e-2 # ℓ2,1 norm regularization parameter ρ = 5e-1 # ADMM penalty parameter maxiter = 50 # number of ADMM iterations diff --git a/examples/scripts/deconv_microscopy_allchn_tv_admm.py b/examples/scripts/deconv_microscopy_allchn_tv_admm.py index 13ec27b87..ccd8c9471 100644 --- a/examples/scripts/deconv_microscopy_allchn_tv_admm.py +++ b/examples/scripts/deconv_microscopy_allchn_tv_admm.py @@ -28,19 +28,24 @@ non-negativity constraint, and $\mathbf{x}$ is the desired image. """ +# isort: off import numpy as np +import logging import ray + +ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 + import scico.numpy as snp from scico import functional, linop, loss, plot from scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices from scico.optimize.admm import ADMM, CircularConvolveSolver """ -Get and preprocess data. We downsample the data for the for purposes of -the example. Reducing the downsampling rate will make the example slower -and more memory-intensive. To run this example on a GPU it may be -necessary to set environment variables +Get and preprocess data. The data is downsampled to limit the memory +requirements and run time of the example. Reducing the downsampling rate +will make the example slower and more memory-intensive. To run this +example on a GPU it may be necessary to set environment variables `XLA_PYTHON_CLIENT_ALLOCATOR=platform` and `XLA_PYTHON_CLIENT_PREALLOCATE=false`. If your GPU does not have enough memory, you can try setting the environment variable @@ -81,11 +86,9 @@ """ -Initialize ray, determine available computing resources, and put large arrays -in object store. +Determine available computing resources, and put large arrays in ray +object store. """ -ray.init() - ngpu = 0 ar = ray.available_resources() ncpu = max(int(ar["CPU"]) // 3, 1) diff --git a/examples/scripts/deconv_microscopy_tv_admm.py b/examples/scripts/deconv_microscopy_tv_admm.py index eee4d8f42..5c9180b60 100644 --- a/examples/scripts/deconv_microscopy_tv_admm.py +++ b/examples/scripts/deconv_microscopy_tv_admm.py @@ -34,10 +34,10 @@ from scico.optimize.admm import ADMM, CircularConvolveSolver """ -Get and preprocess data. We downsample the data for the for purposes of -the example. Reducing the downsampling rate will make the example slower -and more memory-intensive. To run this example on a GPU it may be -necessary to set environment variables +Get and preprocess data. The data is downsampled to limit the memory +requirements and run time of the example. Reducing the downsampling rate +will make the example slower and more memory-intensive. To run this +example on a GPU it may be necessary to set environment variables `XLA_PYTHON_CLIENT_ALLOCATOR=platform` and `XLA_PYTHON_CLIENT_PREALLOCATE=false`. If your GPU does not have enough memory, you can try setting the environment variable diff --git a/examples/scripts/deconv_tv_admm.py b/examples/scripts/deconv_tv_admm.py index 874b70e3c..8acb54d36 100644 --- a/examples/scripts/deconv_tv_admm.py +++ b/examples/scripts/deconv_tv_admm.py @@ -77,7 +77,7 @@ f = loss.SquaredL2Loss(y=y, A=C) # Penalty parameters must be accounted for in the gi functions, not as # additional inputs. -λ = 2.1e-2 # L21 norm regularization parameter +λ = 2.1e-2 # ℓ2,1 norm regularization parameter g = λ * functional.L21Norm() # The append=0 option makes the results of horizontal and vertical # finite differences the same shape, which is required for the L21Norm, diff --git a/examples/scripts/denoise_ptv_pdhg.py b/examples/scripts/denoise_ptv_pdhg.py new file mode 100644 index 000000000..ec5db49c9 --- /dev/null +++ b/examples/scripts/denoise_ptv_pdhg.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Polar Total Variation Denoising (PDHG) +====================================== + +This example compares denoising via standard isotropic total +variation (TV) regularization :cite:`rudin-1992-nonlinear` +:cite:`goldstein-2009-split` and a variant based on local polar +coordinates, as described in :cite:`hossein-2024-total`. It solves the +denoising problem + + $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} + \|_2^2 + \lambda R(\mathbf{x}) \;,$$ + +where $R$ is either the isotropic or polar TV regularizer, via the +primal–dual hybrid gradient (PDHG) algorithm. +""" + + +from xdesign import SiemensStar, discrete_phantom + +import scico.numpy as snp +import scico.random +from scico import functional, linop, loss, metric, plot +from scico.optimize import PDHG +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 256 # image size +phantom = SiemensStar(16) +x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8) +x_gt = x_gt / x_gt.max() + + +""" +Add noise to create a noisy test image. +""" +σ = 0.75 # noise standard deviation +noise, key = scico.random.randn(x_gt.shape, seed=0) +y = x_gt + σ * noise + + +""" +Denoise with standard isotropic total variation. +""" +λ_std = 0.8e0 +f = loss.SquaredL2Loss(y=y) +g_std = λ_std * functional.L21Norm() + +# The append=0 option makes the results of horizontal and vertical finite +# differences the same shape, which is required for the L21Norm. +C = linop.FiniteDifference(input_shape=x_gt.shape, append=0) +tau, sigma = PDHG.estimate_parameters(C, ratio=20.0) +solver = PDHG( + f=f, + g=g_std, + C=C, + tau=tau, + sigma=sigma, + maxiter=200, + itstat_options={"display": True, "period": 10}, +) +print(f"Solving on {device_info()}\n") +solver.solve() +hist_std = solver.itstat_object.history(transpose=True) +x_std = solver.x +print() + + +""" +Denoise with polar total variation for comparison. +""" +# Tune the weight to give the same data fidelty as the isotropic case. +λ_plr = 1.2e0 +g_plr = λ_plr * functional.L1Norm() + +G = linop.PolarGradient(input_shape=x_gt.shape) +D = linop.Diagonal(snp.blockarray([0.3, 1.0]), input_shape=G.shape[0]) +C = D @ G + +tau, sigma = PDHG.estimate_parameters(C, ratio=20.0) +solver = PDHG( + f=f, + g=g_plr, + C=C, + tau=tau, + sigma=sigma, + maxiter=200, + itstat_options={"display": True, "period": 10}, +) +solver.solve() +hist_plr = solver.itstat_object.history(transpose=True) +x_plr = solver.x +print() + + +""" +Compute and print the data fidelity. +""" +for x, name in zip((x_std, x_plr), ("Isotropic", "Polar")): + df = f(x) + print(f"Data fidelity for {(name + ' TV'):12}: {df:.2e} SNR: {metric.snr(x_gt, x):5.2f} dB") + + +""" +Plot results. +""" +plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5)) +fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) +plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) +plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) +plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args) +fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) +fig.colorbar( + ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" +) +fig.suptitle("Denoising comparison") +fig.show() + +# zoomed version +fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10)) +plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args) +plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args) +plot.imview(x_std, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args) +plot.imview(x_plr, title="Polar TV denoising", fig=fig, ax=ax[1, 1], **plt_args) +ax[0, 0].set_xlim(N // 4, N // 4 + N // 2) +ax[0, 0].set_ylim(N // 4, N // 4 + N // 2) +fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01) +fig.colorbar( + ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units" +) +fig.suptitle("Denoising comparison (zoomed)") +fig.show() + + +fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(20, 5)) +plot.plot( + snp.vstack((hist_std.Objective, hist_plr.Objective)).T, + ptyp="semilogy", + title="Objective function", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[0], +) +plot.plot( + snp.vstack((hist_std.Prml_Rsdl, hist_plr.Prml_Rsdl)).T, + ptyp="semilogy", + title="Primal residual", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[1], +) +plot.plot( + snp.vstack((hist_std.Dual_Rsdl, hist_plr.Dual_Rsdl)).T, + ptyp="semilogy", + title="Dual residual", + xlbl="Iteration", + lgnd=("Standard", "Polar"), + fig=fig, + ax=ax[2], +) +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index ef592a444..4f05ba2fc 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -26,7 +26,6 @@ Computed Tomography - ct_astra_unet_train_foam2.py - ct_projector_comparison_2d.py - ct_projector_comparison_3d.py - - ct_multi_cs_tv_admm.py - ct_multi_tv_admm.py Deconvolution @@ -64,6 +63,7 @@ Miscellaneous - demosaic_ppp_bm3d_admm.py - superres_ppp_dncnn_admm.py - denoise_l1tv_admm.py + - denoise_ptv_pdhg.py - denoise_tv_admm.py - denoise_tv_apgm.py - denoise_tv_multi.py @@ -107,6 +107,7 @@ Total Variation - ct_astra_3d_tv_admm.py - ct_astra_3d_tv_padmm.py - ct_astra_weighted_tv_admm.py + - ct_multi_tv_admm.py - ct_svmbir_tv_multi.py - deconv_circ_tv_admm.py - deconv_tv_admm.py @@ -115,6 +116,7 @@ Total Variation - deconv_microscopy_tv_admm.py - deconv_microscopy_allchn_tv_admm.py - denoise_l1tv_admm.py + - denoise_ptv_pdhg.py - denoise_tv_admm.py - denoise_tv_apgm.py - denoise_tv_multi.py @@ -166,6 +168,7 @@ ADMM - ct_tv_admm.py - ct_astra_3d_tv_admm.py - ct_astra_weighted_tv_admm.py + - ct_multi_tv_admm.py - ct_svmbir_tv_multi.py - ct_svmbir_ppp_bm3d_admm_cg.py - ct_svmbir_ppp_bm3d_admm_prox.py @@ -217,6 +220,7 @@ PDHG ^^^^ - ct_svmbir_tv_multi.py + - denoise_ptv_pdhg.py - denoise_tv_multi.py - denoise_cplx_tv_pdhg.py diff --git a/examples/scripts/sparsecode_conv_admm.py b/examples/scripts/sparsecode_conv_admm.py index 611e7238b..0c46356a3 100644 --- a/examples/scripts/sparsecode_conv_admm.py +++ b/examples/scripts/sparsecode_conv_admm.py @@ -75,7 +75,7 @@ """ Set functional and solver parameters. """ -λ = 1e0 # l1-l2 norm regularization parameter +λ = 1e0 # ℓ1-ℓ2 norm regularization parameter ρ = 2e0 # ADMM penalty parameter maxiter = 200 # number of ADMM iterations diff --git a/examples/scripts/sparsecode_conv_md_admm.py b/examples/scripts/sparsecode_conv_md_admm.py index 9f70e9ce3..77d24f19b 100644 --- a/examples/scripts/sparsecode_conv_md_admm.py +++ b/examples/scripts/sparsecode_conv_md_admm.py @@ -96,7 +96,7 @@ """ Set functional and solver parameters. """ -λ = 1e0 # l1-l2 norm regularization parameter +λ = 1e0 # ℓ1-ℓ2 norm regularization parameter ρ0 = 1e0 # ADMM penalty parameters ρ1 = 3e0 maxiter = 200 # number of ADMM iterations diff --git a/examples/scripts/video_rpca_admm.py b/examples/scripts/video_rpca_admm.py index e6c71cb8e..ee8fa5c14 100644 --- a/examples/scripts/video_rpca_admm.py +++ b/examples/scripts/video_rpca_admm.py @@ -62,7 +62,7 @@ Set up an ADMM solver object. """ λ0 = 1e1 # nuclear norm regularization parameter -λ1 = 3e1 # l1 norm regularization parameter +λ1 = 3e1 # ℓ1 norm regularization parameter ρ0 = 2e1 # ADMM penalty parameter ρ1 = 2e1 # ADMM penalty parameter maxiter = 50 # number of ADMM iterations diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index cc23727ea..b0149930d 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -15,6 +15,12 @@ from ._diag import Diagonal, Identity, ScaledIdentity from ._diff import FiniteDifference, SingleAxisFiniteDifference from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function +from ._grad import ( + CylindricalGradient, + PolarGradient, + ProjectedGradient, + SphericalGradient, +) from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes @@ -27,6 +33,10 @@ "DFT", "Diagonal", "FiniteDifference", + "ProjectedGradient", + "PolarGradient", + "CylindricalGradient", + "SphericalGradient", "SingleAxisFiniteDifference", "Identity", "DiagonalReplicated", diff --git a/scico/linop/_grad.py b/scico/linop/_grad.py new file mode 100644 index 000000000..416cc125d --- /dev/null +++ b/scico/linop/_grad.py @@ -0,0 +1,526 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2021-2024 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Non-Cartesian gradient linear operators.""" + + +# Needed to annotate a class method that returns the encapsulating class +# see https://www.python.org/dev/peps/pep-0563/ +from __future__ import annotations + +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import scico.numpy as snp +from scico.numpy import Array, BlockArray +from scico.typing import BlockShape, DType, Shape + +from ._linop import LinearOperator + + +def diffstack(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: + """Compute the discrete difference along multiple axes. + + Apply :func:`snp.diff` along multiple axes, stacking the results on + a newly inserted axis at index 0. The `append` parameter of + :func:`snp.diff` is exploited to give output of the same length as + the input, which is achieved by zero-padding the output at the end + of each axis. + + + """ + if axis is None: + axis = tuple(range(x.ndim)) + elif isinstance(axis, int): + axis = (axis,) + dstack = [ + snp.diff( + x, + axis=ax, + append=x[tuple(slice(-1, None) if i == ax else slice(None) for i in range(x.ndim))], + ) + for ax in axis + ] + return snp.stack(dstack) + + +class ProjectedGradient(LinearOperator): + """Gradient projected onto local coordinate system. + + This class represents a linear operator that computes gradients of + arrays projected onto a local coordinate system that may differ at + every position in the array, as described in + :cite:`hossein-2024-total`. In the 2D illustration below :math:`x` + and :math:`y` represent the standard coordinate system defined by the + array axes, :math:`(g_x, g_y)` is the gradient vector within that + coordinate system, :math:`x'` and :math:`y'` are the local coordinate + axes, and :math:`(g_x', g_y')` is the gradient vector within the + local coordinate system. + + .. image:: /figures/projgrad.svg + :align: center + :alt: Figure illustrating projection of gradient onto local + coordinate system. + + Each of the local coordinate axes (e.g. :math:`x'` and :math:`y'` in + the illustration above) is represented by a separate array in the + `coord` tuple of arrays parameter of the class initializer. + + .. note:: + + This operator should not be confused with the Projected Gradient + optimization algorithm (a special case of Proximal Gradient), with + which it is unrelated. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + coord: Optional[Sequence[Union[Array, BlockArray]]] = None, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Defaults to + ``None``, in which case the gradient is computed along + all axes. + coord: A tuple of arrays, each of which specifies a local + coordinate axis direction. Each member of the tuple + should either be a :class:`jax.Array` or a + :class:`.BlockArray`. If it is the former, it should have + shape :math:`N \times M_0 \times M_1 \times \ldots`, + where :math:`N` is the number of axes specified by + parameter `axes`, and :math:`M_i` is the size of the + :math:`i^{\mrm{th}}` axis. If it is the latter, it should + consist of :math:`N` blocks, each of which has a shape + that is suitable for multiplication with an array of + shape :math:`M_0 \times M_1 \times \ldots`. If `coord` is + a singleton tuple, the result of applying the operator is + a :class:`jax.Array`; otherwise it consists of the + gradients for each of the local coordinate axes in + `coord` stacked into a :class:`.BlockArray`. If `coord` + is ``None``, which is the default, gradients are computed + in the standard axis-aligned coordinate system, and the + return type depends on the number of axes on which the + gradient is calculated, as specified explicitly or + implicitly via the `axes` parameter. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + if axes is None: + # If axes is None, set it to all axes in input shape. + self.axes = tuple(range(len(input_shape))) + else: + # Ensure no invalid axis indices specified. + if snp.any(np.array(axes) >= len(input_shape)): + raise ValueError( + "Invalid axes specified; all elements of `axes` must be less than " + f"len(input_shape)={len(input_shape)}." + ) + self.axes = axes + output_shape: Union[Shape, BlockShape] + if coord is None: + # If coord is None, output shape is determined by number of axes. + if len(self.axes) == 1: + output_shape = input_shape + else: + output_shape = (input_shape,) * len(self.axes) + else: + # If coord is not None, output shape is determined by number of coord arrays. + if len(coord) == 1: + output_shape = input_shape + else: + output_shape = (input_shape,) * len(coord) + self.coord = coord + self.cdiff = cdiff + super().__init__( + input_shape=input_shape, + output_shape=output_shape, + input_dtype=input_dtype, + output_dtype=input_dtype, + jit=jit, + ) + + def _eval(self, x: Array) -> Union[Array, BlockArray]: + + if self.cdiff: + grad = snp.gradient(x, axis=self.axes) + else: + grad = diffstack(x, axis=self.axes) + if self.coord is None: + # If coord attribute is None, just return gradients on specified axes. + if len(self.axes) == 1: + return grad + else: + return snp.blockarray(grad) + else: + # If coord attribute is not None, return gradients projected onto specified local + # coordinate systems. + projgrad = [sum([c[m] * grad[m] for m in range(len(grad))]) for c in self.coord] + if len(self.coord) == 1: + return projgrad[0] + else: + return snp.blockarray(projgrad) + + +class PolarGradient(ProjectedGradient): + """Gradient projected into polar coordinates. + + Compute gradients projected onto angular and/or radial axis + directions, as described in :cite:`hossein-2024-total`. Local + coordinate axes are illustrated in the figure below. + + .. plot:: figures/polargrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `angular` and `radial` is ``True``, the operator + output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + angular: bool = True, + radial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y)`, where :math:`i_x` and + :math:`i_y` are input array axes assigned to :math:`x` + and :math:`y` coordinates respectively. Defaults to + ``None``, in which case the axes are taken to be `(0, 1)`. + center: Center of the polar coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the input array. + angular: Flag indicating whether to compute gradients in the + angular (i.e. tangent to circles) direction. + radial: Flag indicating whether to compute gradients in the + radial (i.e. directed outwards from the origin) direction. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 2: + raise ValueError("Invalid input shape; input must have at least two axes.") + if axes is not None and len(axes) != 2: + raise ValueError("Invalid axes specified; exactly two axes must be specified.") + if not angular and not radial: + raise ValueError("At least one of angular and radial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] + theta = snp.arctan2(g0, g1) + # Re-order theta axes in case indices in axes parameter are not in increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + if len(input_shape) > 2: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + coord = [] + if angular: + coord.append(snp.blockarray([-snp.cos(theta), snp.sin(theta)])) + if radial: + coord.append(snp.blockarray([snp.sin(theta), snp.cos(theta)])) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + coord=coord, + cdiff=cdiff, + jit=jit, + ) + + +class CylindricalGradient(ProjectedGradient): + """Gradient projected into cylindrical coordinates. + + Compute gradients projected onto cylindrical coordinate axes, as + described in :cite:`hossein-2024-total`. The local coordinate axes + are illustrated in the figure below. + + .. plot:: figures/cylindgrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `angular`, `radial`, and `axial` is ``True``, the + operator output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + angular: bool = True, + radial: bool = True, + axial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, + :math:`i_y` and :math:`i_z` are input array axes assigned + to :math:`x`, :math:`y`, and :math:`z` coordinates + respectively. Defaults to ``None``, in which case the + axes are taken to be `(0, 1, 2)`. If an integer, this + operator returns a :class:`jax.Array`. If a tuple or + ``None``, the resulting arrays are stacked into a + :class:`.BlockArray`. + center: Center of the cylindrical coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the two polar axes of the + input array and at the zero index of the axial axis. + angular: Flag indicating whether to compute gradients in the + angular (i.e. tangent to circles) direction. + radial: Flag indicating whether to compute gradients in the + radial (i.e. directed outwards from the origin) direction. + axial: Flag indicating whether to compute gradients in the + direction of the axis of the cylinder. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 3: + raise ValueError("Invalid input shape; input must have at least three axes.") + if axes is not None and len(axes) != 3: + raise ValueError("Invalid axes specified; exactly three axes must be specified.") + if not angular and not radial and not axial: + raise ValueError("At least one of angular, radial, and axial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1, 2) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + center = center.at[-1].set(0) # type: ignore + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1 = snp.ogrid[-center[0] : end[0], -center[1] : end[1]] + g0 = g0[..., np.newaxis] + g1 = g1[..., np.newaxis] + theta = snp.arctan2(g0, g1) + # Re-order theta axes in case indices in axes parameter are not in increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + if len(input_shape) > 3: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + coord = [] + if angular: + coord.append( + snp.blockarray( + [-snp.cos(theta), snp.sin(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if radial: + coord.append( + snp.blockarray( + [snp.sin(theta), snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if axial: + coord.append( + snp.blockarray( + [ + snp.array([0.0], dtype=real_input_dtype), + snp.array([0.0], dtype=real_input_dtype), + snp.array([1.0], dtype=real_input_dtype), + ] + ) + ) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + cdiff=cdiff, + coord=coord, + jit=jit, + ) + + +class SphericalGradient(ProjectedGradient): + """Gradient projected into spherical coordinates. + + Compute gradients projected onto spherical coordinate axes, based on + the approach described in :cite:`hossein-2024-total`. The local + coordinate axes are illustrated in the figure below. + + .. plot:: figures/spheregrad.py + :align: center + :include-source: False + :show-source-link: False + + | + + If only one of `azimuthal`, `polar`, and `radial` is ``True``, the + operator output is a :class:`jax.Array`, otherwise it is a + :class:`.BlockArray`. + """ + + def __init__( + self, + input_shape: Shape, + axes: Optional[Tuple[int, ...]] = None, + center: Optional[Union[Tuple[int, ...], Array]] = None, + azimuthal: bool = True, + polar: bool = True, + radial: bool = True, + cdiff: bool = False, + input_dtype: DType = np.float32, + jit: bool = True, + ): + r""" + Args: + input_shape: Shape of input array. + axes: Axes over which to compute the gradient. Should be a + tuple :math:`(i_x, i_y, i_z)`, where :math:`i_x`, + :math:`i_y` and :math:`i_z` are input array axes assigned + to :math:`x`, :math:`y`, and :math:`z` coordinates + respectively. Defaults to ``None``, in which case the + axes are taken to be `(0, 1, 2)`. If an integer, this + operator returns a :class:`jax.Array`. If a tuple or + ``None``, the resulting arrays are stacked into a + :class:`.BlockArray`. + center: Center of the spherical coordinate system in array + indexing coordinates. Default is ``None``, which places + the center at the center of the input array. + azimuthal: Flag indicating whether to compute gradients in + the azimuthal direction. + polar: Flag indicating whether to compute gradients in the + polar direction. + radial: Flag indicating whether to compute gradients in the + radial direction. + cdiff: If ``True``, estimate gradients using the second order + central different returned by :func:`snp.gradient`, + otherwise use the first order asymmetric difference + returned by :func:`snp.diff`. + input_dtype: `dtype` for input argument. Default is + :attr:`~numpy.float32`. + jit: If ``True``, jit the evaluation, adjoint, and gram + functions of the LinearOperator. + """ + + if len(input_shape) < 3: + raise ValueError("Invalid input shape; input must have at least three axes.") + if axes is not None and len(axes) != 3: + raise ValueError("Invalid axes specified; exactly three axes must be specified.") + if not azimuthal and not polar and not radial: + raise ValueError("At least one of azimuthal, polar, and radial must be True.") + + real_input_dtype = snp.util.real_dtype(input_dtype) + if axes is None: + axes = (0, 1, 2) + axes_shape = [input_shape[ax] for ax in axes] + if center is None: + center = (snp.array(axes_shape, dtype=real_input_dtype) - 1) / 2 + else: + if isinstance(center, (tuple, list)): + center = snp.array(center) + center = center.astype(real_input_dtype) + end = snp.array(axes_shape, dtype=real_input_dtype) - center + g0, g1, g2 = snp.ogrid[-center[0] : end[0], -center[1] : end[1], -center[2] : end[2]] + theta = snp.arctan2(g1, g0) + phi = snp.arctan2(snp.sqrt(g0**2 + g1**2), g2) + # Re-order theta and phi axes in case indices in axes parameter are not in + # increasing order. + axis_order = np.argsort(axes) + theta = snp.transpose(theta, axis_order) + phi = snp.transpose(phi, axis_order) + if len(input_shape) > 3: + # Construct list of input axes that are not included in the gradient axes. + single = tuple(set(range(len(input_shape))) - set(axes)) + # Insert singleton axes to align theta for multiplication with gradients. + theta = snp.expand_dims(theta, single) + phi = snp.expand_dims(phi, single) + coord = [] + if azimuthal: + coord.append( + snp.blockarray( + [snp.sin(theta), -snp.cos(theta), snp.array([0.0], dtype=real_input_dtype)] + ) + ) + if polar: + coord.append( + snp.blockarray( + [snp.cos(phi) * snp.cos(theta), snp.cos(phi) * snp.sin(theta), -snp.sin(phi)] + ) + ) + if radial: + coord.append( + snp.blockarray( + [snp.sin(phi) * snp.cos(theta), snp.sin(phi) * snp.sin(theta), snp.cos(phi)] + ) + ) + super().__init__( + input_shape=input_shape, + input_dtype=input_dtype, + axes=axes, + coord=coord, + cdiff=cdiff, + jit=jit, + ) diff --git a/scico/optimize/_pgm.py b/scico/optimize/_pgm.py index 8375c0b2d..6188f7480 100644 --- a/scico/optimize/_pgm.py +++ b/scico/optimize/_pgm.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -30,14 +30,35 @@ class PGM(Optimizer): - r"""Proximal Gradient Method (PGM) base class. + r"""Proximal gradient method (PGM) algorithm. - Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`, where - :math:`f` and the :math:`g` are instances of :class:`.Functional`. + Minimize a functional of the form :math:`f(\mb{x}) + g(\mb{x})`, + where :math:`f` and the :math:`g` are instances of + :class:`.Functional`. Functional :math:`f` should be differentiable + and have a Lipschitz continuous derivative, and functional :math:`g` + should have a proximal operator defined. + + The step size :math:`\alpha` of the algorithm is defined in terms of + its reciprocal :math:`L`, i.e. :math:`\alpha = 1 / L`. The initial + value for this parameter, `L0`, is required to satisfy + + .. math:: + L_0 \geq K(\nabla f) \;, - Uses helper :class:`StepSize` to provide an estimate of the Lipschitz - constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the - reciprocal of :math:`L`, i.e.: :math:`\alpha = 1 / L`. + where :math:`K(\nabla f)` denotes the Lipschitz constant of the + gradient of :math:`f`. When `f` is an instance of + :class:`.SquaredL2Loss` with a :class:`.LinearOperator` `A`, + + .. math:: + K(\nabla f) = \lambda_{ \mathrm{max} }( A^H A ) = \| A \|_2^2 \;, + + where :math:`\lambda_{\mathrm{max}}(B)` denotes the largest + eigenvalue of :math:`B`. + + The evolution of the step size is controlled by auxiliary class + :class:`.PGMStepSize` and derived classes. The default + :class:`.PGMStepSize` simply sets :math:`L = L_0`, while the derived + classes implement a variety of adaptive strategies. """ def __init__( @@ -52,12 +73,14 @@ def __init__( r""" Args: - f: Loss or Functional object with `grad` defined. - g: Instance of Functional with defined prox method. - L0: Initial estimate of Lipschitz constant of f. + f: Instance of :class:`.Loss` or :class:`.Functional` with + defined `grad` method. + g: Instance of :class:`.Functional` with defined prox method. + L0: Initial estimate of Lipschitz constant of gradient of `f`. x0: Starting point for :math:`\mb{x}`. - step_size: helper :class:`StepSize` to estimate the Lipschitz - constant of f. + step_size: Instance of an auxiliary class of type + :class:`.PGMStepSize` determining the evolution of the + algorithm step size. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ @@ -75,7 +98,7 @@ def __init__( step_size = PGMStepSize() self.step_size: PGMStepSize = step_size self.step_size.internal_init(self) - self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of f) + self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of ∇f) self.fixed_point_residual = snp.inf def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]: @@ -150,16 +173,14 @@ def step(self): class AcceleratedPGM(PGM): - r"""Accelerated Proximal Gradient Method (AcceleratedPGM) base class. - - Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`. + r"""Accelerated proximal gradient method (APGM) algorithm. Minimize a function of the form :math:`f(\mb{x}) + g(\mb{x})`, where :math:`f` and the :math:`g` are instances of :class:`.Functional`. The accelerated form of PGM is also known as FISTA :cite:`beck-2009-fast`. - For documentation on inherited attributes, see :class:`.PGM`. + See :class:`.PGM` for more detailed documentation. """ def __init__( @@ -173,12 +194,14 @@ def __init__( ): r""" Args: - f: Loss or Functional object with `grad` defined. - g: Instance of Functional with defined prox method. - L0: Initial estimate of Lipschitz constant of f. + f: Instance of :class:`.Loss` or :class:`.Functional` with + defined `grad` method. + g: Instance of :class:`.Functional` with defined prox method. + L0: Initial estimate of Lipschitz constant of gradient of `f`. x0: Starting point for :math:`\mb{x}`. - step_size: helper :class:`StepSize` to estimate the Lipschitz - constant of f. + step_size: Instance of an auxiliary class of type + :class:`.PGMStepSize` determining the evolution of the + algorithm step size. **kwargs: Additional optional parameters handled by initializer of base class :class:`.Optimizer`. """ diff --git a/scico/plot.py b/scico/plot.py index 1cdd1adc7..6c0375be1 100644 --- a/scico/plot.py +++ b/scico/plot.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2022 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -360,7 +360,7 @@ def surf( # https://stackoverflow.com/a/35221116 if ax.name != "3d": ax.remove() - ax = fig.add_subplot(*ax.get_geometry(), projection="3d") + ax = fig.add_subplot(ax.get_subplotspec(), projection="3d") if elev is not None or azim is not None: ax.view_init(elev=elev, azim=azim) diff --git a/scico/test/linop/test_grad.py b/scico/test/linop/test_grad.py new file mode 100644 index 000000000..19687519c --- /dev/null +++ b/scico/test/linop/test_grad.py @@ -0,0 +1,218 @@ +from itertools import combinations + +import numpy as np + +import jax + +import pytest + +import scico.numpy as snp +from scico.linop import CylindricalGradient, PolarGradient, SphericalGradient +from scico.numpy import Array, BlockArray +from scico.random import randn + + +class TestPolarGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("outflags", [(True, True), (True, False), (False, True)]) + @pytest.mark.parametrize("center", [None, (-2, 3), (1.2, -3.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20), None), + ((20, 21), (0, 1)), + ((16, 17, 3), (0, 1)), + ((2, 17, 16), (1, 2)), + ((2, 17, 16, 3), (2, 1)), + ], + ) + @pytest.mark.parametrize("cdiff", [True, False]) + def test_eval(self, cdiff, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + angular, radial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = PolarGradient( + input_shape, + axes=axes, + center=center, + angular=angular, + radial=radial, + cdiff=cdiff, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + if angular and radial: + assert isinstance(Ax, BlockArray) + assert len(Ax.shape) == 2 + assert Ax[0].shape == input_shape + assert Ax[1].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + assert snp.abs(snp.sum(c0 * c1)) < 1e-5 + + +class TestCylindricalGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize( + "outflags", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + ], + ) + @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20, 20), None), + ((17, 18, 19), (0, 1, 2)), + ((16, 17, 18, 3), (0, 1, 2)), + ((2, 17, 16, 15), (1, 2, 3)), + ((17, 2, 16, 15), (0, 2, 3)), + ((17, 2, 16, 15), (3, 2, 0)), + ], + ) + def test_eval(self, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1, 2) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + angular, radial, axial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = CylindricalGradient( + input_shape, + axes=axes, + center=center, + angular=angular, + radial=radial, + axial=axial, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + Nc = sum([angular, radial, axial]) + if Nc > 1: + assert isinstance(Ax, BlockArray) + assert len(Ax) == Nc + for n in range(Nc): + assert Ax[n].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() + assert snp.abs(s) < 1e-5 + + +class TestSphericalGradient: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("jit", [True, False]) + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize( + "outflags", + [ + (True, True, True), + (True, True, False), + (True, False, True), + (True, False, False), + (False, True, True), + (False, True, False), + (False, False, True), + ], + ) + @pytest.mark.parametrize("center", [None, (-2, 3, 0), (1.2, -3.5, 1.5)]) + @pytest.mark.parametrize( + "shape_axes", + [ + ((20, 20, 20), None), + ((17, 18, 19), (0, 1, 2)), + ((16, 17, 18, 3), (0, 1, 2)), + ((2, 17, 16, 15), (1, 2, 3)), + ((17, 2, 16, 15), (0, 2, 3)), + ((17, 2, 16, 15), (3, 2, 0)), + ], + ) + def test_eval(self, shape_axes, center, outflags, input_dtype, jit): + + input_shape, axes = shape_axes + if axes is None: + testaxes = (0, 1, 2) + else: + testaxes = axes + if center is not None: + axes_shape = [input_shape[ax] for ax in testaxes] + center = (snp.array(axes_shape) - 1) / 2 + snp.array(center) + azimuthal, polar, radial = outflags + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + A = SphericalGradient( + input_shape, + axes=axes, + center=center, + azimuthal=azimuthal, + polar=polar, + radial=radial, + input_dtype=input_dtype, + jit=jit, + ) + Ax = A @ x + Nc = sum([azimuthal, polar, radial]) + if Nc > 1: + assert isinstance(Ax, BlockArray) + assert len(Ax) == Nc + for n in range(Nc): + assert Ax[n].shape == input_shape + else: + assert isinstance(Ax, Array) + assert Ax.shape == input_shape + assert Ax.dtype == input_dtype + + # Test orthogonality of coordinate axes + coord = A.coord + for n0, n1 in combinations(range(len(coord)), 2): + c0 = coord[n0] + c1 = coord[n1] + s = sum([c0[m] * c1[m] for m in range(len(c0))]).sum() + assert snp.abs(s) < 1e-5