diff --git a/CHANGES.rst b/CHANGES.rst index e5aa718b..eb4c3b97 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -20,7 +20,7 @@ Version 0.0.6 (unreleased) • Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to ``scico.flax.save_variables`` and ``scico.flax.load_variables`` respectively. -• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.34. • Support ``flax`` versions 0.8.0 to 0.9.0. diff --git a/data b/data index c1233896..1ceadbbb 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit c12338966b1b9f92554066743b1a8b664c7b7e24 +Subproject commit 1ceadbbbe6bef9f364fd76dbda44ff6e185a7d10 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 58ba847f..b97a52bd 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -34,12 +34,11 @@ Computed Tomography examples/ct_svmbir_ppp_bm3d_admm_cg examples/ct_svmbir_ppp_bm3d_admm_prox examples/ct_fan_svmbir_ppp_bm3d_admm_prox - examples/ct_astra_modl_train_foam2 - examples/ct_astra_odp_train_foam2 - examples/ct_astra_unet_train_foam2 + examples/ct_modl_train_foam2 + examples/ct_odp_train_foam2 + examples/ct_unet_train_foam2 examples/ct_projector_comparison_2d examples/ct_projector_comparison_3d - examples/ct_multi_cs_tv_admm examples/ct_multi_tv_admm Deconvolution @@ -96,7 +95,7 @@ Miscellaneous examples/denoise_dncnn_universal examples/diffusercam_tv_admm examples/video_rpca_admm - examples/ct_astra_datagen_foam2 + examples/ct_datagen_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/denoise_datagen_bsds @@ -181,10 +180,10 @@ Machine Learning .. toctree:: :maxdepth: 1 - examples/ct_astra_datagen_foam2 - examples/ct_astra_modl_train_foam2 - examples/ct_astra_odp_train_foam2 - examples/ct_astra_unet_train_foam2 + examples/ct_datagen_foam2 + examples/ct_modl_train_foam2 + examples/ct_odp_train_foam2 + examples/ct_unet_train_foam2 examples/deconv_datagen_bsds examples/deconv_datagen_foam1 examples/deconv_modl_train_foam1 diff --git a/examples/jnb.py b/examples/jnb.py index 9135add1..f51c3f9a 100644 --- a/examples/jnb.py +++ b/examples/jnb.py @@ -60,10 +60,10 @@ def py_file_to_string(src): # Process remainder of source file for line in srcfile: - if re.match("^input\(", line): # end processing when input statement encountered + if re.match(r"^input\(", line): # end processing when input statement encountered break line = re.sub('^r"""', '"""', line) # remove r from r""" - line = re.sub(":cite:\`([^`]+)\`", r'', line) # fix cite format + line = re.sub(r":cite:\`([^`]+)\`", r'', line) # fix cite format lines.append(line) # Backtrack through list of lines to remove trailing newlines diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 446186a7..95e11f7c 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -33,11 +33,11 @@ Computed Tomography PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox) `ct_fan_svmbir_ppp_bm3d_admm_prox.py `_ PPP (with BM3D) Fan-Beam CT Reconstruction - `ct_astra_modl_train_foam2.py `_ - CT Training and Reconstructions with MoDL - `ct_astra_odp_train_foam2.py `_ - CT Training and Reconstructions with ODP - `ct_astra_unet_train_foam2.py `_ + `ct_modl_train_foam2.py `_ + CT Training and Reconstruction with MoDL + `ct_odp_train_foam2.py `_ + CT Training and Reconstruction with ODP + `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `ct_projector_comparison_2d.py `_ 2D X-ray Transform Comparison @@ -123,7 +123,7 @@ Miscellaneous TV-Regularized 3D DiffuserCam Reconstruction `video_rpca_admm.py `_ Video Decomposition via Robust PCA - `ct_astra_datagen_foam2.py `_ + `ct_datagen_foam2.py `_ CT Data Generation for NN Training `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training @@ -239,13 +239,13 @@ Sparsity Machine Learning ^^^^^^^^^^^^^^^^ - `ct_astra_datagen_foam2.py `_ + `ct_datagen_foam2.py `_ CT Data Generation for NN Training - `ct_astra_modl_train_foam2.py `_ - CT Training and Reconstructions with MoDL - `ct_astra_odp_train_foam2.py `_ - CT Training and Reconstructions with ODP - `ct_astra_unet_train_foam2.py `_ + `ct_modl_train_foam2.py `_ + CT Training and Reconstruction with MoDL + `ct_odp_train_foam2.py `_ + CT Training and Reconstruction with ODP + `ct_unet_train_foam2.py `_ CT Training and Reconstructions with UNet `deconv_datagen_bsds.py `_ Blurred Data Generation (Natural Images) for NN Training diff --git a/examples/scripts/ct_astra_datagen_foam2.py b/examples/scripts/ct_datagen_foam2.py similarity index 93% rename from examples/scripts/ct_astra_datagen_foam2.py rename to examples/scripts/ct_datagen_foam2.py index 4e6fb97c..2fc9be59 100644 --- a/examples/scripts/ct_astra_datagen_foam2.py +++ b/examples/scripts/ct_datagen_foam2.py @@ -14,6 +14,7 @@ """ # isort: off +import os import numpy as np import logging @@ -21,6 +22,9 @@ ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087 +# Set an arbitrary processor count (only applies if GPU is not available). +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + from scico import plot from scico.flax.examples import load_ct_data diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_modl_train_foam2.py similarity index 96% rename from examples/scripts/ct_astra_modl_train_foam2.py rename to examples/scripts/ct_modl_train_foam2.py index a06d7b81..19a3d810 100644 --- a/examples/scripts/ct_astra_modl_train_foam2.py +++ b/examples/scripts/ct_modl_train_foam2.py @@ -5,8 +5,8 @@ # with the package. r""" -CT Training and Reconstructions with MoDL -========================================= +CT Training and Reconstruction with MoDL +======================================== This example demonstrates the training and application of a model-based deep learning (MoDL) architecture described in @@ -65,7 +65,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform2D +from scico.linop.xray import XRayTransform2D """ Prepare parallel processing. Set an arbitrary processor count (only @@ -89,16 +89,17 @@ """ -Build CT projection operator. +Build CT projection operator. Parameters are chosen so that the operator +is equivalent to the one used to generate the training data. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), - det_spacing=1, - det_count=N, angles=angles, -) # CT projection operator -A = (1.0 / N) * A # normalized + det_count=int(N * 1.05 / np.sqrt(2.0)), + dx=1.0 / np.sqrt(2), +) +A = (1.0 / N) * A # normalize projection operator """ diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py index 95711679..f2f13fd8 100644 --- a/examples/scripts/ct_multi_tv_admm.py +++ b/examples/scripts/ct_multi_tv_admm.py @@ -38,7 +38,7 @@ np.random.seed(1234) x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)) -det_count = N +det_count = int(N * 1.05 / np.sqrt(2.0)) det_spacing = np.sqrt(2) diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_odp_train_foam2.py similarity index 94% rename from examples/scripts/ct_astra_odp_train_foam2.py rename to examples/scripts/ct_odp_train_foam2.py index 03753c1b..e5cd58ae 100644 --- a/examples/scripts/ct_astra_odp_train_foam2.py +++ b/examples/scripts/ct_odp_train_foam2.py @@ -5,8 +5,8 @@ # with the package. r""" -CT Training and Reconstructions with ODP -======================================== +CT Training and Reconstruction with ODP +======================================= This example demonstrates the training of the unrolled optimization with deep priors (ODP) gradient descent architecture described in @@ -72,7 +72,7 @@ from scico import metric, plot from scico.flax.examples import load_ct_data from scico.flax.train.traversals import clip_positive, construct_traversal -from scico.linop.xray.astra import XRayTransform2D +from scico.linop.xray import XRayTransform2D platform = get_backend().platform @@ -92,21 +92,22 @@ """ -Build CT projection operator. +Build CT projection operator. Parameters are chosen so that the operator +is equivalent to the one used to generate the training data. """ angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles A = XRayTransform2D( input_shape=(N, N), - det_spacing=1, - det_count=N, angles=angles, -) # CT projection operator -A = (1.0 / N) * A # normalized + det_count=int(N * 1.05 / np.sqrt(2.0)), + dx=1.0 / np.sqrt(2), +) +A = (1.0 / N) * A # normalize projection operator """ Build training and testing structures. Inputs are the sinograms and -outpus are the original generated foams. Keep training and testing +outputs are the original generated foams. Keep training and testing partitions. """ numtr = 320 diff --git a/examples/scripts/ct_projector_comparison_2d.py b/examples/scripts/ct_projector_comparison_2d.py index 2e5d02d3..0a47c7b0 100644 --- a/examples/scripts/ct_projector_comparison_2d.py +++ b/examples/scripts/ct_projector_comparison_2d.py @@ -29,9 +29,6 @@ Create a ground truth image. """ N = 512 - -det_count = int(jnp.ceil(jnp.sqrt(2 * N**2))) - x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N) x_gt = jnp.array(x_gt) @@ -41,17 +38,18 @@ """ num_angles = 500 angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False) +det_count = int(N * 1.02 / jnp.sqrt(2.0)) timer = Timer() projectors = {} timer.start("scico_init") -projectors["scico"] = XRayTransform2D((N, N), angles) +projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count) timer.stop("scico_init") timer.start("astra_init") projectors["astra"] = astra.XRayTransform2D( - (N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0 + (N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0 ) timer.stop("astra_init") diff --git a/examples/scripts/ct_astra_unet_train_foam2.py b/examples/scripts/ct_unet_train_foam2.py similarity index 100% rename from examples/scripts/ct_astra_unet_train_foam2.py rename to examples/scripts/ct_unet_train_foam2.py diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst index 4f05ba2f..e36e9fbd 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -21,9 +21,9 @@ Computed Tomography - ct_svmbir_ppp_bm3d_admm_cg.py - ct_svmbir_ppp_bm3d_admm_prox.py - ct_fan_svmbir_ppp_bm3d_admm_prox.py - - ct_astra_modl_train_foam2.py - - ct_astra_odp_train_foam2.py - - ct_astra_unet_train_foam2.py + - ct_modl_train_foam2.py + - ct_odp_train_foam2.py + - ct_unet_train_foam2.py - ct_projector_comparison_2d.py - ct_projector_comparison_3d.py - ct_multi_tv_admm.py @@ -73,7 +73,7 @@ Miscellaneous - denoise_dncnn_universal.py - diffusercam_tv_admm.py - video_rpca_admm.py - - ct_astra_datagen_foam2.py + - ct_datagen_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - denoise_datagen_bsds.py @@ -143,10 +143,10 @@ Sparsity Machine Learning ^^^^^^^^^^^^^^^^ - - ct_astra_datagen_foam2.py - - ct_astra_modl_train_foam2.py - - ct_astra_odp_train_foam2.py - - ct_astra_unet_train_foam2.py + - ct_datagen_foam2.py + - ct_modl_train_foam2.py + - ct_odp_train_foam2.py + - ct_unet_train_foam2.py - deconv_datagen_bsds.py - deconv_datagen_foam1.py - deconv_modl_train_foam1.py diff --git a/misc/conda/install_conda.sh b/misc/conda/install_conda.sh index 73defcaa..744d1836 100755 --- a/misc/conda/install_conda.sh +++ b/misc/conda/install_conda.sh @@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh export PATH="$CONDAHOME/bin:$PATH" hash -r conda config --set always_yes yes -conda install mamba -n base -c conda-forge conda update -q conda conda info -a diff --git a/misc/conda/make_conda_env.sh b/misc/conda/make_conda_env.sh index b5aa4411..cab2b7e2 100755 --- a/misc/conda/make_conda_env.sh +++ b/misc/conda/make_conda_env.sh @@ -50,7 +50,7 @@ EOF ) # Requirements that cannot be installed via conda (i.e. have to use pip) NOCONDA=$(cat <<-EOF -flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train] +flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train] EOF ) @@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init` conda activate $ENVNM # Q: why not `source activate`? A: not always in the path # Add conda-forge channel -conda config --env --append channels conda-forge - -# Install mamba -conda install mamba -n base -c conda-forge +conda config --append channels conda-forge # Install required conda packages (and extra useful packages) -mamba install $CONDA_FLAGS $CONDAREQ ipython +conda install $CONDA_FLAGS $CONDAREQ ipython # Utility ffmpeg is required by imageio for reading mp4 video files # it can also be installed via the system package manager, .e.g. # sudo apt install ffmpeg if [ "$(which ffmpeg)" = '' ]; then - mamba install $CONDA_FLAGS ffmpeg + conda install $CONDA_FLAGS ffmpeg fi # Install jaxlib and jax diff --git a/requirements.txt b/requirements.txt index 1b0e5359..feab852c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ scipy>=1.6.0 imageio>=2.17 tifffile matplotlib -jaxlib>=0.4.3,<=0.4.33 -jax>=0.4.3,<=0.4.33 +jaxlib>=0.4.3,<=0.4.34 +jax>=0.4.3,<=0.4.34 orbax-checkpoint>=0.5.0 flax>=0.8.0,<=0.9.0 pyabel>=0.9.0 diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py index e7a99030..e89f10cf 100644 --- a/scico/flax/examples/data_generation.py +++ b/scico/flax/examples/data_generation.py @@ -51,16 +51,9 @@ class UnitCircle: from jax.lib.xla_bridge import get_backend from scico.linop import CircularConvolve +from scico.linop.xray import XRayTransform2D from scico.numpy import Array -try: - import astra # noqa: F401 -except ImportError: - have_astra = False -else: - have_astra = True - from scico.linop.xray.astra import XRayTransform2D - class Foam2(UnitCircle): """Foam-like material with two attenuations. @@ -218,10 +211,8 @@ def generate_ct_data( - **sino** : (:class:`jax.Array`): Corresponding sinograms. - **fbp** : (:class:`jax.Array`) Corresponding filtered back projections. """ - if not (have_ray and have_xdesign and have_astra): - raise RuntimeError( - "Packages ray, xdesign, and astra are required for use of this function." - ) + if not (have_ray and have_xdesign): + raise RuntimeError("Packages ray and xdesign are required for use of this function.") # Generate input data. start_time = time() @@ -234,17 +225,17 @@ def generate_ct_data( # Configure a CT projection operator to generate synthetic measurements. angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles - gt_sh = (size, size) - detector_spacing = 1.0 - A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator - + gt_shape = (size, size) + dx = 1.0 / np.sqrt(2) + det_count = int(size * 1.05 / np.sqrt(2.0)) + A = XRayTransform2D(gt_shape, angles, dx=dx, det_count=det_count) # Compute sinograms in parallel. start_time = time() if nproc > 1: # shard array imgshd = img.reshape((nproc, -1, size, size, 1)) sinoshd = batched_f(A, imgshd) - sino = sinoshd.reshape((-1, nproj, size, 1)) + sino = sinoshd.reshape((-1, nproj, sinoshd.shape[-2], 1)) else: sino = vector_f(A, img) @@ -261,8 +252,8 @@ def generate_ct_data( # Normalize sinogram. sino = sino / size - # Shift FBP to [0,1] range. - fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min()) + # Clip FBP to [0,1] range. + fbp = np.clip(fbp, 0, 1) if verbose: # pragma: no cover platform = get_backend().platform diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index 256791a7..1cbef6fc 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -117,7 +117,7 @@ def conj_prox( return v - lam * self.prox(v / lam, 1.0 / lam, **kwargs) def grad(self, x: Union[Array, BlockArray]): - r"""Evaluates the gradient of this functional at :math:`\mb{x}`. + r"""Evaluate the gradient of this functional at :math:`\mb{x}`. Args: x: Point at which to evaluate gradient. diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 83bd8462..9d459db5 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -9,7 +9,7 @@ from functools import partial -from typing import Optional +from typing import Optional, Tuple from warnings import warn import numpy as np @@ -273,14 +273,14 @@ def _back_project( y[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1] * weights_valid, axis=0 ) - return HTy + return HTy.astype(jnp.float32) @staticmethod @partial(jax.jit, static_argnames=["nx"]) @partial(jax.vmap, in_axes=(None, None, None, 0, None)) def _calc_weights( - x0: ArrayLike, dx: ArrayLike, nx: Shape, angle: float, y0: float - ) -> snp.Array: + x0: ArrayLike, dx: ArrayLike, nx: Shape, angles: ArrayLike, y0: float + ) -> Tuple[snp.Array, snp.Array]: """ Args: @@ -288,12 +288,12 @@ def _calc_weights( dx: Pixel side length in x- and y-direction. Units are such that the detector bins have length 1.0. nx: Input image shape. - angle: (num_angles,) array of angles in radians. Pixels are + angles: (num_angles,) array of angles in radians. Pixels are projected onto units vectors pointing in these directions. (This argument is `vmap`ed.) y0: Location of the edge of the first detector bin. """ - u = [jnp.cos(angle), jnp.sin(angle)] + u = [jnp.cos(angles), jnp.sin(angles)] Px0 = x0[0] * u[0] + x0[1] * u[1] - y0 Pdx = [dx[0] * u[0], dx[1] * u[1]] Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]])) diff --git a/scico/plot.py b/scico/plot.py index 6c0375be..72ae4962 100644 --- a/scico/plot.py +++ b/scico/plot.py @@ -13,6 +13,7 @@ # This module is copied from https://github.com/bwohlberg/sporco +import os import sys import numpy as np @@ -820,7 +821,8 @@ def config_notebook_plotting(): Configure plotting functions for inline plotting within a Jupyter Notebook shell. This function has no effect when not within a notebook shell, and may therefore be used within a normal python - script. + script. If environment variable ``MATPLOTLIB_IPYNB_BACKEND`` is set, + the matplotlib backend is explicitly set to the specified value. """ # Check whether running within a notebook shell and have @@ -828,8 +830,9 @@ def config_notebook_plotting(): module = sys.modules[__name__] if _in_notebook() and module.plot.__name__ == "plot": - # Set inline backend (i.e. %matplotlib inline) if in a notebook shell - set_notebook_plot_backend() + # Set backend if specified by environment variable + if "MATPLOTLIB_IPYNB_BACKEND" in os.environ: + set_notebook_plot_backend(os.environ["MATPLOTLIB_IPYNB_BACKEND"]) # Replace plot function with a wrapper function that discards # its return value (within a notebook with inline plotting, plots diff --git a/scico/test/flax/test_examples_flax.py b/scico/test/flax/test_examples_flax.py index 72c084dd..ce265d05 100644 --- a/scico/test/flax/test_examples_flax.py +++ b/scico/test/flax/test_examples_flax.py @@ -12,7 +12,6 @@ generate_ct_data, generate_foam1_images, generate_foam2_images, - have_astra, have_ray, have_xdesign, ) @@ -75,8 +74,8 @@ def random_data_gen(seed, N, ndata): @pytest.mark.skipif( - not have_astra or not have_ray or not have_xdesign, - reason="astra, ray, or xdesign package not installed", + not have_ray or not have_xdesign, + reason="ray or xdesign package not installed", ) def test_ct_data_generation(): N = 32 @@ -90,7 +89,7 @@ def random_img_gen(seed, size, ndata): img, sino, fbp = generate_ct_data(nimg, N, nproj, imgfunc=random_img_gen) assert img.shape == (nimg, N, N, 1) - assert sino.shape == (nimg, nproj, N, 1) + assert sino.shape == (nimg, nproj, sino.shape[2], 1) assert fbp.shape == (nimg, N, N, 1) diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py index b63a88a8..d49df64b 100644 --- a/scico/test/flax/test_inv.py +++ b/scico/test/flax/test_inv.py @@ -6,18 +6,12 @@ import jax.numpy as jnp from jax import lax -import pytest - from scico import flax as sflax from scico import random from scico.flax.examples import PaddedCircularConvolve, build_blur_kernel -from scico.flax.examples.data_generation import have_astra from scico.flax.train.traversals import clip_positive, clip_range, construct_traversal from scico.linop import CircularConvolve, Identity - -if have_astra: - from scico.linop.xray.astra import XRayTransform2D - +from scico.linop.xray import XRayTransform2D os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" @@ -153,7 +147,6 @@ def test_train_odpdcnv_default(self): np.testing.assert_array_less(1e-2 * np.ones(alphaval.shape), alphaval) -@pytest.mark.skipif(not have_astra, reason="astra package not installed") class TestCT: def setup_method(self, method): self.N = 32 # signal size @@ -162,11 +155,10 @@ def setup_method(self, method): xt, key = random.randn((2 * self.bsize, self.N, self.N, self.chn), seed=4321) self.nproj = 60 # number of projections - angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles + angles = np.linspace(0, np.pi, self.nproj, endpoint=False, dtype=np.float32) self.opCT = XRayTransform2D( input_shape=(self.N, self.N), det_count=self.N, - det_spacing=1.0, angles=angles, ) # Radon transform operator a_f = lambda v: jnp.atleast_3d(self.opCT(v.squeeze()))