Skip to content

Commit

Permalink
Merge branch 'main' into brendt/issue530
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Jun 12, 2024
2 parents 013fd6f + d990555 commit 5786ce6
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 20 deletions.
55 changes: 40 additions & 15 deletions scico/optimize/_admmaux.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -371,15 +371,15 @@ class CircularConvolveSolver(LinearSubproblemSolver):
r"""Solver for linear operators diagonalized in the DFT domain.
Specialization of :class:`.LinearSubproblemSolver` for the case
where :code:`f` is an instance of :class:`.SquaredL2Loss`, the
forward operator :code:`f.A` is either an instance of
:class:`.Identity` or :class:`.CircularConvolve`, and the
:code:`C_i` are all shift invariant linear operators, examples of
which include instances of :class:`.Identity` as well as some
instances (depending on initializer parameters) of
:class:`.CircularConvolve` and :class:`.FiniteDifference`.
None of the instances of :class:`.CircularConvolve` may sum over any
of their axes.
where :code:`f` is ``None``, or an instance of
:class:`.SquaredL2Loss` with a forward operator :code:`f.A` that is
either an instance of :class:`.Identity` or
:class:`.CircularConvolve`, and the :code:`C_i` are all shift
invariant linear operators, examples of which include instances of
:class:`.Identity` as well as some instances (depending on
initializer parameters) of :class:`.CircularConvolve` and
:class:`.FiniteDifference`. None of the instances of
:class:`.CircularConvolve` may sum over any of their axes.
Attributes:
admm (:class:`.ADMM`): ADMM solver object to which the solver is
Expand All @@ -388,11 +388,29 @@ class CircularConvolveSolver(LinearSubproblemSolver):
equation to be solved.
"""

def __init__(self):
"""Initialize a :class:`CircularConvolveSolver` object."""
def __init__(self, ndims: Optional[int] = None):
"""Initialize a :class:`CircularConvolveSolver` object.
Args:
ndims: Number of trailing dimensions of the input and kernel
involved in the :class:`.CircularConvolve` convolutions.
In most cases this value is automatically determined from
the optimization problem specification, but this is not
possible when :code:`f` is ``None`` and none of the
:code:`C_i` are of type :class:`.CircularConvolve`. When
not ``None``, this parameter overrides the automatic
mechanism.
"""
self.ndims = ndims

def internal_init(self, admm: soa.ADMM):
if admm.f is not None:
if admm.f is None:
is_cc = [isinstance(C, CircularConvolve) for C in admm.C_list]
if any(is_cc):
auto_ndims = admm.C_list[is_cc.index(True)].ndims
else:
auto_ndims = None
else:
if not isinstance(admm.f, SquaredL2Loss):
raise TypeError(
"CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; "
Expand All @@ -403,20 +421,27 @@ def internal_init(self, admm: soa.ADMM):
"CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve "
f"or scico.linop.Identity; got {type(admm.f.A)}."
)
auto_ndims = admm.f.A.ndims if isinstance(admm.f.A, CircularConvolve) else None

if self.ndims is None:
self.ndims = auto_ndims
super().internal_init(admm)

self.real_result = is_real_dtype(admm.C_list[0].input_dtype)

# All of the C operators are assumed to be linear and shift invariant
# but this is not checked.
lhs_op_list = [
rho * CircularConvolve.from_operator(C.gram_op)
rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)
for rho, C in zip(admm.rho_list, admm.C_list)
]
A_lhs = reduce(lambda a, b: a + b, lhs_op_list)
if self.admm.f is not None:
A_lhs += 2.0 * admm.f.scale * CircularConvolve.from_operator(admm.f.A.gram_op)
A_lhs += (
2.0
* admm.f.scale
* CircularConvolve.from_operator(admm.f.A.gram_op, ndims=self.ndims)
)

self.A_lhs = A_lhs

Expand Down
70 changes: 65 additions & 5 deletions scico/test/optimize/test_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,23 +363,29 @@ def test_admm_quadratic_matrix(self):
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5


@pytest.mark.parametrize("extra_axis", (False, True))
@pytest.mark.parametrize("center", (None, [-1.0, 2.5]))
class TestCircularConvolveSolve:
def setup_method(self, method):

@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(self, extra_axis, center):
np.random.seed(12345)
Nx = 8
x = np.pad(np.ones((Nx, Nx), dtype=np.float32), Nx)
x = snp.pad(snp.ones((Nx, Nx), dtype=np.float32), Nx)
Npsf = 3
psf = snp.ones((Npsf, Npsf), dtype=np.float32) / (Npsf**2)
if extra_axis:
x = x[np.newaxis]
psf = psf[np.newaxis]
self.A = linop.CircularConvolve(
h=psf,
input_shape=x.shape,
input_dtype=np.float32,
h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32, h_center=center
)
self.y = self.A(x)
λ = 1e-2
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g_list = [λ * functional.L1Norm()]
self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)]
yield

def test_admm(self):
maxiter = 50
Expand All @@ -406,6 +412,60 @@ def test_admm(self):
x0=self.A.adj(self.y),
subproblem_solver=CircularConvolveSolver(),
)
assert admm_dft.subproblem_solver.A_lhs.ndims == 2
x_dft = admm_dft.solve()
np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
assert metric.mse(x_lin, x_dft) < 1e-9


@pytest.mark.parametrize("with_cconv", (False, True))
class TestSpecialCaseCircularConvolveSolve:

@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(self, with_cconv):
np.random.seed(12345)
Nx = 8
x = snp.pad(snp.ones((1, Nx, Nx), dtype=np.float32), Nx)
if with_cconv:
Npsf = 3
psf = snp.ones((1, Npsf, Npsf), dtype=np.float32) / (Npsf**2)
C0 = linop.CircularConvolve(h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32)
else:
C0 = linop.FiniteDifference(input_shape=x.shape, axes=(1, 2), circular=True)
C1 = linop.Identity(input_shape=x.shape)
self.y = C0(x)
self.g_list = [loss.SquaredL2Loss(y=self.y), functional.L2Norm()]
self.C_list = [C0, C1]
self.with_cconv = with_cconv
yield

def test_admm(self):
maxiter = 50
ρ = 1e-1
rho_list = [ρ, ρ]
admm_lin = ADMM(
f=None,
g_list=self.g_list,
C_list=self.C_list,
rho_list=rho_list,
maxiter=maxiter,
itstat_options={"display": False},
x0=self.C_list[0].adj(self.y),
subproblem_solver=LinearSubproblemSolver(),
)
x_lin = admm_lin.solve()
ndims = None if self.with_cconv else 2
admm_dft = ADMM(
f=None,
g_list=self.g_list,
C_list=self.C_list,
rho_list=rho_list,
maxiter=maxiter,
itstat_options={"display": False},
x0=self.C_list[0].adj(self.y),
subproblem_solver=CircularConvolveSolver(ndims=ndims),
)
assert admm_dft.subproblem_solver.A_lhs.ndims == 2
x_dft = admm_dft.solve()
np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
assert metric.mse(x_lin, x_dft) < 1e-9
Expand Down

0 comments on commit 5786ce6

Please sign in to comment.