Skip to content

Commit

Permalink
Added a new Functional for TV Norm implementing its proximal operator…
Browse files Browse the repository at this point in the history
… using the fast subiteration free algorithm proposed by Kamilov, 2016
  • Loading branch information
Salman Naqvi committed Oct 5, 2023
1 parent 216ffc8 commit 1b7f583
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
2 changes: 2 additions & 0 deletions scico/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
L21Norm,
NuclearNorm,
L1MinusL2Norm,
TV2DNorm,
)
from ._indicator import NonNegativeIndicator, L2BallIndicator
from ._denoiser import BM3D, BM4D, DnCNN
Expand All @@ -46,6 +47,7 @@
"BM3D",
"BM4D",
"DnCNN",
"TV2DNorm",
]

# Imported items in __all__ appear to originate in top-level functional module
Expand Down
104 changes: 104 additions & 0 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from scico.numpy import Array, BlockArray, count_nonzero
from scico.numpy.linalg import norm
from scico.numpy.util import no_nan_divide
from scico.linop import FiniteDifference

from ._functional import Functional

Expand Down Expand Up @@ -477,3 +478,106 @@ def prox(
svdU, svdS, svdV = snp.linalg.svd(v, full_matrices=False)
svdS = snp.maximum(0, svdS - lam)
return svdU @ snp.diag(svdS) @ svdV


class TV2DNorm(Functional):
r"""The :math:`\ell_{TV}` norm.
For a :math:`M \times N` matrix, :math:`\mb{A}`, by default,
.. math::
\norm{\mb{A}}_{TV} = \sum_{n=1}^N \sum_{m=1}^M
\abs{\nabla{A}_{m,n}} \;.
This norm currently only has proximal operator defined only for
2 dimensional data.
For `BlockArray` inputs, the :math:`\ell_{TV}` norm follows the
reduction rules described in :class:`BlockArray`.
A typical use case is computing the anisotropic total variation norm.
"""

has_eval = True
has_prox = True

def __init__(self, dims, tau: float = 1.0):
r"""
Args:
tau: Parameter :math:`\tau` in the norm definition.
"""
self.dims = dims
self.tau = tau

def __call__(self, x: Union[Array, BlockArray]) -> float:
r"""Return the :math:`\ell_{TV}` norm of an array."""
y = 0
gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True)
grads = gradOp @ x
for g in grads:
y += snp.abs(g)
return self.tau * snp.sum(y)

def prox(
self, x: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of the :math:`\ell_{TV}` norm.
Evaluate proximal operator of the TV norm
:cite:`tip-2016-kamilov`.
Args:
v: Input array :math:`\mb{v}`.
lam: Proximal parameter :math:`\lam`.
kwargs: Additional arguments that may be used by derived
classes.
"""
D = 2
K = 2*D
thresh = snp.sqrt(2) * K * self.tau * lam

y = snp.zeros_like(x)
for ax in range(2):
y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=False), thresh), axis=ax, shift=False))
y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=True), thresh), axis=ax, shift=True))
y = y.at[:].divide(K)

return y

def ht2(self, x, axis, shift):
s = x.shape
w = snp.zeros(s)
C = 1 / snp.sqrt(2)
if shift:
x = snp.roll(x, -1, axis=axis)

m = s[axis] // 2
if not axis:
w = w.at[:m, :].set(C * (x[1::2, :] + x[::2, :]))
w = w.at[m:, :].set(C * (x[1::2, :] - x[::2, :]))
else:
w = w.at[:, :m].set(C * (x[:, 1::2] + x[:, ::2]))
w = w.at[:, m:].set(C * (x[:, 1::2] - x[:, ::2]))
return w

def iht2(self, w, axis, shift):
s = snp.shape(w)
y = snp.zeros(s)
C = 1 / snp.sqrt(2)
m = s[axis] // 2
if not axis:
y = y.at[::2, :].set(C * (w[:m, :] - w[m:, :]))
y = y.at[1::2, :].set(C * (w[:m, :] + w[m:, :]))
else:
y = y.at[:, ::2].set(C * (w[:, :m] - w[:, m:]))
y = y.at[:, 1::2].set(C * (w[:, :m] + w[:, m:]))

if shift:
y = snp.roll(y, 1, axis)

return y

def shrink(self, x, tau):
threshed = snp.maximum(snp.abs(x)-tau, 0)
threshed = threshed.at[:].multiply(snp.sign(x))
return threshed

0 comments on commit 1b7f583

Please sign in to comment.