Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #528 #532

Merged
merged 5 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions scico/optimize/_admmaux.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down Expand Up @@ -371,15 +371,15 @@ class CircularConvolveSolver(LinearSubproblemSolver):
r"""Solver for linear operators diagonalized in the DFT domain.

Specialization of :class:`.LinearSubproblemSolver` for the case
where :code:`f` is an instance of :class:`.SquaredL2Loss`, the
forward operator :code:`f.A` is either an instance of
:class:`.Identity` or :class:`.CircularConvolve`, and the
:code:`C_i` are all shift invariant linear operators, examples of
which include instances of :class:`.Identity` as well as some
instances (depending on initializer parameters) of
:class:`.CircularConvolve` and :class:`.FiniteDifference`.
None of the instances of :class:`.CircularConvolve` may sum over any
of their axes.
where :code:`f` is ``None``, or an instance of
:class:`.SquaredL2Loss` with a forward operator :code:`f.A` that is
either an instance of :class:`.Identity` or
:class:`.CircularConvolve`, and the :code:`C_i` are all shift
invariant linear operators, examples of which include instances of
:class:`.Identity` as well as some instances (depending on
initializer parameters) of :class:`.CircularConvolve` and
:class:`.FiniteDifference`. None of the instances of
:class:`.CircularConvolve` may sum over any of their axes.

Attributes:
admm (:class:`.ADMM`): ADMM solver object to which the solver is
Expand All @@ -388,11 +388,29 @@ class CircularConvolveSolver(LinearSubproblemSolver):
equation to be solved.
"""

def __init__(self):
"""Initialize a :class:`CircularConvolveSolver` object."""
def __init__(self, ndims: Optional[int] = None):
"""Initialize a :class:`CircularConvolveSolver` object.

Args:
ndims: Number of trailing dimensions of the input and kernel
involved in the :class:`.CircularConvolve` convolutions.
In most cases this value is automatically determined from
the optimization problem specification, but this is not
possible when :code:`f` is ``None`` and none of the
:code:`C_i` are of type :class:`.CircularConvolve`. When
not ``None``, this parameter overrides the automatic
mechanism.
"""
self.ndims = ndims

def internal_init(self, admm: soa.ADMM):
if admm.f is not None:
if admm.f is None:
is_cc = [isinstance(C, CircularConvolve) for C in admm.C_list]
if any(is_cc):
auto_ndims = admm.C_list[is_cc.index(True)].ndims
else:
auto_ndims = None
else:
if not isinstance(admm.f, SquaredL2Loss):
raise TypeError(
"CircularConvolveSolver requires f to be a scico.loss.SquaredL2Loss; "
Expand All @@ -403,20 +421,27 @@ def internal_init(self, admm: soa.ADMM):
"CircularConvolveSolver requires f.A to be a scico.linop.CircularConvolve "
f"or scico.linop.Identity; got {type(admm.f.A)}."
)
auto_ndims = admm.f.A.ndims if isinstance(admm.f.A, CircularConvolve) else None

if self.ndims is None:
self.ndims = auto_ndims
super().internal_init(admm)

self.real_result = is_real_dtype(admm.C_list[0].input_dtype)

# All of the C operators are assumed to be linear and shift invariant
# but this is not checked.
lhs_op_list = [
rho * CircularConvolve.from_operator(C.gram_op)
rho * CircularConvolve.from_operator(C.gram_op, ndims=self.ndims)
for rho, C in zip(admm.rho_list, admm.C_list)
]
A_lhs = reduce(lambda a, b: a + b, lhs_op_list)
if self.admm.f is not None:
A_lhs += 2.0 * admm.f.scale * CircularConvolve.from_operator(admm.f.A.gram_op)
A_lhs += (
2.0
* admm.f.scale
* CircularConvolve.from_operator(admm.f.A.gram_op, ndims=self.ndims)
)

self.A_lhs = A_lhs

Expand Down
70 changes: 65 additions & 5 deletions scico/test/optimize/test_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,23 +363,29 @@ def test_admm_quadratic_matrix(self):
assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-5


@pytest.mark.parametrize("extra_axis", (False, True))
@pytest.mark.parametrize("center", (None, [-1.0, 2.5]))
class TestCircularConvolveSolve:
def setup_method(self, method):

@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(self, extra_axis, center):
np.random.seed(12345)
Nx = 8
x = np.pad(np.ones((Nx, Nx), dtype=np.float32), Nx)
x = snp.pad(snp.ones((Nx, Nx), dtype=np.float32), Nx)
Npsf = 3
psf = snp.ones((Npsf, Npsf), dtype=np.float32) / (Npsf**2)
if extra_axis:
x = x[np.newaxis]
psf = psf[np.newaxis]
self.A = linop.CircularConvolve(
h=psf,
input_shape=x.shape,
input_dtype=np.float32,
h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32, h_center=center
)
self.y = self.A(x)
λ = 1e-2
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.g_list = [λ * functional.L1Norm()]
self.C_list = [linop.FiniteDifference(input_shape=x.shape, circular=True)]
yield

def test_admm(self):
maxiter = 50
Expand All @@ -406,6 +412,60 @@ def test_admm(self):
x0=self.A.adj(self.y),
subproblem_solver=CircularConvolveSolver(),
)
assert admm_dft.subproblem_solver.A_lhs.ndims == 2
x_dft = admm_dft.solve()
np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
assert metric.mse(x_lin, x_dft) < 1e-9


@pytest.mark.parametrize("with_cconv", (False, True))
class TestSpecialCaseCircularConvolveSolve:

@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(self, with_cconv):
np.random.seed(12345)
Nx = 8
x = snp.pad(snp.ones((1, Nx, Nx), dtype=np.float32), Nx)
if with_cconv:
Npsf = 3
psf = snp.ones((1, Npsf, Npsf), dtype=np.float32) / (Npsf**2)
C0 = linop.CircularConvolve(h=psf, input_shape=x.shape, ndims=2, input_dtype=np.float32)
else:
C0 = linop.FiniteDifference(input_shape=x.shape, axes=(1, 2), circular=True)
C1 = linop.Identity(input_shape=x.shape)
self.y = C0(x)
self.g_list = [loss.SquaredL2Loss(y=self.y), functional.L2Norm()]
self.C_list = [C0, C1]
self.with_cconv = with_cconv
yield

def test_admm(self):
maxiter = 50
ρ = 1e-1
rho_list = [ρ, ρ]
admm_lin = ADMM(
f=None,
g_list=self.g_list,
C_list=self.C_list,
rho_list=rho_list,
maxiter=maxiter,
itstat_options={"display": False},
x0=self.C_list[0].adj(self.y),
subproblem_solver=LinearSubproblemSolver(),
)
x_lin = admm_lin.solve()
ndims = None if self.with_cconv else 2
admm_dft = ADMM(
f=None,
g_list=self.g_list,
C_list=self.C_list,
rho_list=rho_list,
maxiter=maxiter,
itstat_options={"display": False},
x0=self.C_list[0].adj(self.y),
subproblem_solver=CircularConvolveSolver(ndims=ndims),
)
assert admm_dft.subproblem_solver.A_lhs.ndims == 2
x_dft = admm_dft.solve()
np.testing.assert_allclose(x_dft, x_lin, atol=1e-4, rtol=0)
assert metric.mse(x_lin, x_dft) < 1e-9
Expand Down
Loading