diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 598a26aa2..ee6c78e36 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -12,7 +12,7 @@ from ._circconv import CircularConvolve from ._convolve import Convolve, ConvolveByX from ._dft import DFT -from ._diag import Diagonal, Identity +from ._diag import Diagonal, Identity, ScaledIdentity from ._diff import FiniteDifference, SingleAxisFiniteDifference from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function from ._linop import ComposedLinearOperator, LinearOperator @@ -35,6 +35,7 @@ "Pad", "Crop", "Reshape", + "ScaledIdentity", "Slice", "Sum", "Transpose", diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index deacce792..a6456b3f6 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -8,8 +8,6 @@ """Circular convolution linear operators.""" import math -import operator -from functools import partial from typing import Optional, Sequence, Tuple, Union import numpy as np @@ -205,7 +203,7 @@ def _adj(self, x: snp.Array) -> snp.Array: # type: ignore H_adj_x = H_adj_x.real return H_adj_x - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.ndims != other.ndims: raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") @@ -218,7 +216,7 @@ def __add__(self, other): h_is_dft=True, ) - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.ndims != other.ndims: raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.") @@ -241,10 +239,6 @@ def __mul__(self, scalar): h_is_dft=True, ) - @_wrap_mul_div_scalar - def __rmul__(self, scalar): - return self * scalar - @_wrap_mul_div_scalar def __truediv__(self, scalar): return CircularConvolve( diff --git a/scico/linop/_convolve.py b/scico/linop/_convolve.py index 01f8789de..ab90690c7 100644 --- a/scico/linop/_convolve.py +++ b/scico/linop/_convolve.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -12,9 +12,6 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial - import numpy as np from jax.dtypes import result_type @@ -85,7 +82,7 @@ def __init__( def _eval(self, x: snp.Array) -> snp.Array: return convolve(x, self.h, mode=self.mode) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -102,7 +99,7 @@ def __add__(self, other): raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -129,17 +126,6 @@ def __mul__(self, scalar): adj_fn=lambda x: snp.conj(scalar) * self.adj(x), ) - @_wrap_mul_div_scalar - def __rmul__(self, scalar): - return Convolve( - h=self.h * scalar, - input_shape=self.input_shape, - input_dtype=result_type(self.input_dtype, type(scalar)), - mode=self.mode, - output_shape=self.output_shape, - adj_fn=lambda x: snp.conj(scalar) * self.adj(x), - ) - @_wrap_mul_div_scalar def __truediv__(self, scalar): return Convolve( @@ -216,7 +202,7 @@ def __init__( def _eval(self, h: snp.Array) -> snp.Array: return convolve(self.x, h, mode=self.mode) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -231,7 +217,7 @@ def __add__(self, other): ) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.mode != other.mode: raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.") @@ -259,17 +245,6 @@ def __mul__(self, scalar): adj_fn=lambda x: snp.conj(scalar) * self.adj(x), ) - @_wrap_mul_div_scalar - def __rmul__(self, scalar): - return ConvolveByX( - x=self.x * scalar, - input_shape=self.input_shape, - input_dtype=result_type(self.input_dtype, type(scalar)), - mode=self.mode, - output_shape=self.output_shape, - adj_fn=lambda x: snp.conj(scalar) * self.adj(x), - ) - @_wrap_mul_div_scalar def __truediv__(self, scalar): return ConvolveByX( diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index 0e14a0068..d9b179b83 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -11,8 +11,6 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial from typing import Optional, Union import scico.numpy as snp @@ -23,10 +21,7 @@ from ._linop import LinearOperator, _wrap_add_sub -__all__ = [ - "Diagonal", - "Identity", -] +__all__ = ["Diagonal", "Identity", "ScaledIdentity"] class Diagonal(LinearOperator): @@ -35,7 +30,7 @@ class Diagonal(LinearOperator): def __init__( self, diagonal: Union[Array, BlockArray], - input_shape: Optional[Shape] = None, + input_shape: Optional[Union[Shape, BlockShape]] = None, input_dtype: Optional[DType] = None, **kwargs, ): @@ -48,18 +43,18 @@ def __init__( input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. """ - self.diagonal = diagonal + self._diagonal = diagonal if input_shape is None: - input_shape = self.diagonal.shape + input_shape = self._diagonal.shape if input_dtype is None: - input_dtype = self.diagonal.dtype + input_dtype = self._diagonal.dtype if isinstance(diagonal, BlockArray) and is_nested(input_shape): - output_shape = broadcast_nested_shapes(input_shape, self.diagonal.shape) + output_shape = broadcast_nested_shapes(input_shape, self._diagonal.shape) elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): - output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) + output_shape = snp.broadcast_shapes(input_shape, self._diagonal.shape) elif isinstance(diagonal, BlockArray): raise ValueError("Parameter diagonal was a BlockArray but input_shape was not nested.") else: @@ -73,8 +68,13 @@ def __init__( **kwargs, ) - def _eval(self, x): - return x * self.diagonal + def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: + return self._diagonal * x + + @property + def diagonal(self) -> Union[Array, BlockArray]: + """Return an array representing the diagonal component.""" + return self._diagonal @property def T(self) -> Diagonal: @@ -99,13 +99,13 @@ def gram_op(self) -> Diagonal: """ return Diagonal(diagonal=self.diagonal.conj() * self.diagonal) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): if self.diagonal.shape == other.diagonal.shape: return Diagonal(diagonal=self.diagonal + other.diagonal) raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): if self.diagonal.shape == other.diagonal.shape: return Diagonal(diagonal=self.diagonal - other.diagonal) @@ -115,10 +115,6 @@ def __sub__(self, other): def __mul__(self, scalar): return Diagonal(diagonal=self.diagonal * scalar) - @_wrap_mul_div_scalar - def __rmul__(self, scalar): - return Diagonal(diagonal=self.diagonal * scalar) - @_wrap_mul_div_scalar def __truediv__(self, scalar): return Diagonal(diagonal=self.diagonal / scalar) @@ -128,9 +124,7 @@ def __matmul__(self, other): if isinstance(other, Diagonal): if self.shape == other.shape: return Diagonal(diagonal=self.diagonal * other.diagonal) - raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") - else: return self(other) @@ -156,10 +150,126 @@ def norm(self, ord=None): # pylint: disable=W0622 mord = snp.inf if mord not in ordfunc: raise ValueError(f"Invalid value {ord} for parameter ord.") - return ordfunc[mord](self.diagonal) + return ordfunc[mord](self._diagonal) + + +class ScaledIdentity(Diagonal): + """Scaled identity operator.""" + + def __init__( + self, + scalar: float, + input_shape: Union[Shape, BlockShape], + input_dtype: DType = snp.float32, + **kwargs, + ): + """ + Args: + scalar: Scaling of the identity. + input_shape: Shape of input array. + input_dtype: `dtype` of input argument. + """ + if is_nested(input_shape): + diagonal = scalar * snp.ones(((),) * len(input_shape), dtype=input_dtype) + else: + diagonal = scalar * snp.ones((), dtype=input_dtype) + super().__init__( + diagonal=diagonal, + input_shape=input_shape, + input_dtype=input_dtype, + **kwargs, + ) + + @property + def diagonal(self) -> Union[Array, BlockArray]: + return self._diagonal * snp.ones(self.input_shape, dtype=self.input_dtype) + + def conj(self) -> ScaledIdentity: + """Complex conjugate of this :class:`ScaledIdentity`.""" + return ScaledIdentity( + scalar=self._diagonal.conj(), input_shape=self.input_shape, input_dtype=self.input_dtype + ) + + @property + def gram_op(self) -> ScaledIdentity: + """Gram operator of this :class:`ScaledIdentity`.""" + return ScaledIdentity( + scalar=self._diagonal * self._diagonal.conj(), + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + + @_wrap_add_sub + def __add__(self, other): + if self.input_shape == other.input_shape: + return ScaledIdentity( + scalar=self._diagonal + other._diagonal, + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") + + @_wrap_add_sub + def __sub__(self, other): + if self.input_shape == other.input_shape: + return ScaledIdentity( + scalar=self._diagonal - other._diagonal, + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.") + + @_wrap_mul_div_scalar + def __mul__(self, scalar): + return ScaledIdentity( + scalar=self._diagonal * scalar, + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + + @_wrap_mul_div_scalar + def __truediv__(self, scalar): + return ScaledIdentity( + scalar=self._diagonal / scalar, + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + + def __matmul__(self, other): + # self @ other + if isinstance(other, Diagonal): + if self.shape != other.shape: + raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.") + if isinstance(other, ScaledIdentity): + return ScaledIdentity( + scalar=self._diagonal * other._diagonal, + input_shape=self.input_shape, + input_dtype=self.input_dtype, + ) + else: + return Diagonal(diagonal=self._diagonal * other.diagonal) + else: + return self(other) + + def norm(self, ord=None): # pylint: disable=W0622 + """Compute the matrix norm of the identity operator. + + Valid values of `ord` and the corresponding norm definition + are those listed under "norm for matrices" in the + :func:`scico.numpy.linalg.norm` documentation. + """ + N = self.input_size + if ord is None or ord == "fro": + return snp.abs(self._diagonal) * snp.sqrt(N) + elif ord == "nuc": + return snp.abs(self._diagonal) * N + elif ord in (-snp.inf, -1, -2, 1, 2, snp.inf): + return snp.abs(self._diagonal) + else: + raise ValueError(f"Invalid value {ord} for parameter ord.") -class Identity(Diagonal): +class Identity(ScaledIdentity): """Identity operator.""" def __init__( @@ -168,11 +278,33 @@ def __init__( """ Args: input_shape: Shape of input array. + input_dtype: `dtype` of input argument. """ - super().__init__(diagonal=snp.ones(input_shape, dtype=input_dtype), **kwargs) + super().__init__( + scalar=1.0, + input_shape=input_shape, + input_dtype=input_dtype, + **kwargs, + ) def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return x + @property + def diagonal(self) -> Union[Array, BlockArray]: + return snp.ones(self.input_shape, dtype=self.input_dtype) + + def conj(self) -> Identity: + """Complex conjugate of this :class:`Diagonal`.""" + return self + + @property + def gram_op(self) -> Identity: + """Gram operator of this :class:`Identity`.""" + return self + + def __matmul__(self, other): + return other + def __rmatmul__(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: return x diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 309f92a9a..54012be6d 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -11,8 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -import operator -from functools import partial, wraps +from functools import wraps from typing import Callable, Optional, Union import numpy as np @@ -29,28 +28,51 @@ from scico.typing import BlockShape, DType, Shape -def _wrap_add_sub(func: Callable, op: Callable) -> Callable: - r"""Wrapper function for defining `__add__`, `__sub__`. - - Wrapper function for defining `__add__`,` __sub__` between - :class:`LinearOperator` and other objects. - - Handles shape checking and dispatching based on operand types: - - - If one of the two operands is an :class:`.Operator`, an - :class:`.Operator` is returned. - - If both operands are :class:`LinearOperator` of different types, - a generic :class:`LinearOperator` is returned. - - If both operands are :class:`LinearOperator` of the same type, a - special constructor can be called +def _wrap_add_sub(func: Callable) -> Callable: + r"""Wrapper function for defining `__add__` and `__sub__`. + + Wrapper function for defining `__add__` and ` __sub__` between + :class:`LinearOperator` and derived classes. Operations + between :class:`LinearOperator` and :class:`.Operator` + types are also supported. + + Handles shape checking and function dispatch based on types of + operands `a` and `b` in the call `func(a, b)`. Note that `func` + will always be a method of the type of `a`, and since this wrapper + should only be applied within :class:`LinearOperator` or derived + classes, we can assume that `a` is always an instance of + :class:`LinearOperator`. The general rule for dispatch is that the + `__add__` or `__sub__` operator of the nearest common base class + of `a` and `b` should be called. If `b` is derived from `a`, this + entails using the operator defined in the class of `a`, and + vice-versa. If one of the operands is not a descendant of the other + in the class hierarchy, then it is assumed that their common base + class is either :class:`.Operator` or :class:`LinearOperator`, + depending on the type of `b`. + + - If `b` is not an instance of :class:`.Operator`, a :exc:`TypeError` + is raised. + - If the shapes of `a` and `b` do not match, a :exc:`ValueError` is + raised. + - If `b` is an instance of the type of `a` then `func(a, b)` is + called where `func` is the argument of this wrapper, i.e. + the unwrapped function defined in the class of `a`. + - If `a` is an instance of the type of `b` then `func(a, b)` is + called where `func` is the unwrapped function defined in the class + of `b`. + - If `b` is a :class:`LinearOperator` then `func(a, b)` is called + where `func` is the operator defined in :class:`LinearOperator`. + - Othwerwise, `func(a, b)` is called where `func` is the operator + defined in :class:`.Operator`. Args: func: should be either `.__add__` or `.__sub__`. - op: functional equivalent of func, ex. op.add for func = - `__add__`. + + Returns: + Wrapped version of `func`. Raises: - ValueError: If the shape of both operators does not match. + ValueError: If the shapes of two operators do not match. TypeError: If one of the two operands is not an :class:`.Operator` or :class:`LinearOperator`. """ @@ -62,31 +84,40 @@ def wrapper( if isinstance(b, Operator): if a.shape == b.shape: if isinstance(b, type(a)): - # same type of linop, eg convolution can have special - # behavior (see Conv2d.__add__) + # b is an instance of the class of a: call the unwrapped operator + # defined in the class of a, which is the func argument of this + # wrapper return func(a, b) - if isinstance( - b, LinearOperator - ): # LinearOperator + LinearOperator -> LinearOperator - return LinearOperator( - input_shape=a.input_shape, - output_shape=a.output_shape, - eval_fn=lambda x: op(a(x), b(x)), - adj_fn=lambda x: op(a(x), b(x)), - input_dtype=a.input_dtype, - output_dtype=result_type(a.output_dtype, b.output_dtype), - ) - # LinearOperator + Operator -> Operator - return Operator( - input_shape=a.input_shape, - output_shape=a.output_shape, - eval_fn=lambda x: op(a(x), b(x)), - input_dtype=a.input_dtype, - output_dtype=result_type(a.output_dtype, b.output_dtype), - ) + if isinstance(a, type(b)): + # a is an instance of class b: call the unwrapped operator + # defined in the class of b. A test is required because + # the operators defined in Operator and non-LinearOperator + # derived classes are not wrapped. + if hasattr(getattr(type(b), func.__name__), "_unwrapped"): + uwfunc = getattr(type(b), func.__name__)._unwrapped + else: + uwfunc = getattr(type(b), func.__name__) + return uwfunc(a, b) + # The most general approach here would be to automatically determine + # the nearest common ancestor of the classes of a and b (e.g. as + # discussed in https://stackoverflow.com/a/58290475 ), but the + # simpler approach adopted here is to just assume that the common + # base of two classes that do not have an ancestor-descendant + # relationship is either Operator or LinearOperator. + if isinstance(b, LinearOperator): + # LinearOperator + LinearOperator -> LinearOperator + uwfunc = getattr(LinearOperator, func.__name__)._unwrapped + return uwfunc(a, b) + # LinearOperator + Operator -> Operator (access to the function + # definition differs from that for LinearOperator because + # Operator __add__ and __sub__ are not wrapped) + uwfunc = getattr(Operator, func.__name__) + return uwfunc(a, b) raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") + wrapper._unwrapped = func # type: ignore + return wrapper @@ -177,7 +208,7 @@ def jit(self): self._adj = jax.jit(self._adj) self._gram = jax.jit(self._gram) - @partial(_wrap_add_sub, op=operator.add) + @_wrap_add_sub def __add__(self, other): return LinearOperator( input_shape=self.input_shape, @@ -188,7 +219,7 @@ def __add__(self, other): output_dtype=result_type(self.output_dtype, other.output_dtype), ) - @partial(_wrap_add_sub, op=operator.sub) + @_wrap_add_sub def __sub__(self, other): return LinearOperator( input_shape=self.input_shape, @@ -212,14 +243,7 @@ def __mul__(self, other): @_wrap_mul_div_scalar def __rmul__(self, other): - return LinearOperator( - input_shape=self.input_shape, - output_shape=self.output_shape, - eval_fn=lambda x: other * self(x), - adj_fn=lambda x: snp.conj(other) * self.adj(x), - input_dtype=self.input_dtype, - output_dtype=result_type(self.output_dtype, other), - ) + return self.__mul__(other) # scalar multiplication is commutative @_wrap_mul_div_scalar def __truediv__(self, other): diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 662df0e49..5ae43b85f 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -18,10 +18,10 @@ import numpy as np import jax.numpy as jnp -from jax.dtypes import result_type from jax.typing import ArrayLike import scico.numpy as snp +from scico.operator._operator import Operator from ._diag import Identity from ._linop import LinearOperator @@ -45,17 +45,17 @@ def wrapper(a, b): raise ValueError(f"Shapes {a.matrix_shape} and {b.shape} do not match.") + if isinstance(b, Operator): + if a.shape != b.shape: + raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") + if isinstance(b, LinearOperator): - if a.shape == b.shape: - return LinearOperator( - input_shape=a.input_shape, - output_shape=a.output_shape, - eval_fn=lambda x: op(a(x), b(x)), - input_dtype=a.input_dtype, - output_dtype=result_type(a.output_dtype, b.output_dtype), - ) + uwfunc = getattr(LinearOperator, func.__name__)._unwrapped + return uwfunc(a, b) - raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.") + if isinstance(b, Operator): + uwfunc = getattr(Operator, func.__name__) + return uwfunc(a, b) raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") diff --git a/scico/numpy/_blockarray.py b/scico/numpy/_blockarray.py index 72694608a..4c844baba 100644 --- a/scico/numpy/_blockarray.py +++ b/scico/numpy/_blockarray.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed @@ -61,7 +61,7 @@ def dtype(self): """Return the dtype of the blocks, which must currently be homogeneous. This allows `snp.zeros(x.shape, x.dtype)` to work without a mechanism - to handle to lists of dtypes. + to handle lists of dtypes. """ return self.arrays[0].dtype diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py index b14be9006..024881eaa 100644 --- a/scico/numpy/_wrappers.py +++ b/scico/numpy/_wrappers.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py index eabdf47ee..9a1fe6cbe 100644 --- a/scico/operator/_operator.py +++ b/scico/operator/_operator.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -40,6 +40,9 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable: func: should be either `.__mul__()`, `.__rmul__()`, or `.__truediv__()`. + Returns: + Wrapped version of `func`. + Raises: TypeError: If a binop with the form `binop(Operator, other)` is called and `other` is not a scalar. @@ -52,6 +55,8 @@ def wrapper(a, b): raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.") + wrapper._unwrapped = func # type: ignore + return wrapper diff --git a/scico/solver.py b/scico/solver.py index 02a76b487..69f218274 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -723,7 +723,7 @@ def __init__( if isinstance(D, Diagonal): D = D.diagonal - if not D.ndim == 1: + if D.ndim != 1: raise ValueError("If Diagonal, D should have a 1D diagonal.") else: D = jnp.array(D) @@ -734,7 +734,7 @@ def __init__( W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal - if not W.ndim == 1: + if W.ndim != 1: raise ValueError("If Diagonal, W should have a 1D diagonal.") elif not isinstance(W, Array): raise TypeError( diff --git a/scico/test/linop/test_binop.py b/scico/test/linop/test_binop.py new file mode 100644 index 000000000..826286305 --- /dev/null +++ b/scico/test/linop/test_binop.py @@ -0,0 +1,61 @@ +import operator as op + +import pytest + +import scico.numpy as snp +from scico import linop +from scico.operator import Abs, Operator + + +class TestBinaryOp: + def setup_method(self, method): + self.input_shape = (5,) + self.input_dtype = snp.float32 + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case1(self, operator): + A = linop.Convolve( + snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode="same" + ) + B = Abs(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(operator(A, B)) == Operator + assert type(operator(B, A)) == Operator + assert type(operator(2.0 * A, 3.0 * B)) == Operator + assert type(operator(2.0 * B, 3.0 * A)) == Operator + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case2(self, operator): + A = linop.Convolve( + snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode="same" + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(operator(A, B)) == linop.LinearOperator + assert type(operator(B, A)) == linop.LinearOperator + assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator + assert type(operator(2.0 * B, 3.0 * A)) == linop.LinearOperator + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case3(self, operator): + A = linop.SingleAxisFiniteDifference( + input_shape=self.input_shape, input_dtype=self.input_dtype, circular=True + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(operator(A, B)) == linop.LinearOperator + assert type(operator(B, A)) == linop.LinearOperator + assert type(operator(2.0 * A, 3.0 * B)) == linop.LinearOperator + assert type(operator(2.0 * B, 3.0 * A)) == linop.LinearOperator + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case4(self, operator): + A = linop.ScaledIdentity( + scalar=0.5, input_shape=self.input_shape, input_dtype=self.input_dtype + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(operator(A, B)) == linop.ScaledIdentity + assert type(operator(B, A)) == linop.ScaledIdentity + assert type(operator(2.0 * A, 3.0 * B)) == linop.ScaledIdentity + assert type(operator(2.0 * B, 3.0 * A)) == linop.ScaledIdentity diff --git a/scico/test/linop/test_diag.py b/scico/test/linop/test_diag.py index 303228170..9d8d7bdc7 100644 --- a/scico/test/linop/test_diag.py +++ b/scico/test/linop/test_diag.py @@ -187,3 +187,146 @@ def test_norm_except(self): D = linop.Diagonal(diagonal=diagonal) with pytest.raises(ValueError): n = D.norm(ord=3) + + +class TestScaledIdentity: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + input_shapes = [(8,), (8, 12), ((3,), (4, 5))] + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("input_shape", input_shapes) + def test_eval(self, input_shape, input_dtype): + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + scalar, key = randn((), dtype=input_dtype, key=key) + + Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype) + assert (Id @ x).shape == Id.output_shape + snp.testing.assert_allclose(scalar * x, Id @ x, rtol=1e-5) + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + @pytest.mark.parametrize("input_shape", input_shapes) + def test_binary_op(self, input_shape, operator): + input_dtype = np.float32 + diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key) + x, key = randn(input_shape, dtype=input_dtype, key=key) + scalar, key = randn((), dtype=input_dtype, key=key) + + Id = linop.ScaledIdentity(scalar, input_shape=input_shape) + D = linop.Diagonal(diagonal=diagonal) + + IdD = operator(Id, D) + assert isinstance(IdD, linop.Diagonal) + snp.testing.assert_allclose(IdD @ x, operator(scalar, diagonal) * x, rtol=1e-6) + + DId = operator(D, Id) + assert isinstance(DId, linop.Diagonal) + snp.testing.assert_allclose(DId @ x, operator(diagonal, scalar) * x, rtol=1e-6) + + def test_scale(self): + input_shape = (5,) + input_dtype = np.float32 + scalar1, key = randn((), dtype=input_dtype, key=self.key) + scalar2, key = randn((), dtype=input_dtype, key=key) + + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + Id = linop.ScaledIdentity(scalar=scalar1, input_shape=input_shape, input_dtype=input_dtype) + + sId = scalar2 * Id + assert isinstance(sId, linop.ScaledIdentity) + snp.testing.assert_allclose(sId @ x, scalar1 * scalar2 * x, rtol=1e-6) + + Ids = Id * scalar2 + assert isinstance(Ids, linop.ScaledIdentity) + snp.testing.assert_allclose(Ids @ x, scalar1 * scalar2 * x, rtol=1e-6) + + Idds = Id / scalar2 + assert isinstance(Idds, linop.ScaledIdentity) + snp.testing.assert_allclose(Idds @ x, x * scalar1 / scalar2, rtol=1e-6) + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("ord", [None, "fro", "nuc", -np.inf, np.inf, 1, -1, 2, -2]) + def test_norm(self, input_dtype, ord): + input_shape = (5,) + scalar, key = randn((), dtype=input_dtype, key=self.key) + + Id = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape, input_dtype=input_dtype) + D = linop.Diagonal( + diagonal=scalar * snp.ones(input_shape), + input_shape=input_shape, + input_dtype=input_dtype, + ) + n1 = Id.norm(ord=ord) + n2 = D.norm(ord=ord) + snp.testing.assert_allclose(n1, n2, rtol=1e-6) + + def test_norm_except(self): + input_shape = (5,) + + Id = linop.Identity(input_shape=input_shape, input_dtype=np.float32) + with pytest.raises(ValueError): + n = Id.norm(ord=3) + + +class TestIdentity: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + input_shapes = [(8,), (8, 12), ((3,), (4, 5))] + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) + @pytest.mark.parametrize("input_shape", input_shapes) + def test_eval(self, input_shape, input_dtype): + x, key = randn(input_shape, dtype=input_dtype, key=self.key) + + Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype) + assert (Id @ x).shape == Id.output_shape + snp.testing.assert_allclose(x, Id @ x, rtol=1e-5) + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + @pytest.mark.parametrize("input_shape", input_shapes) + def test_binary_op(self, input_shape, operator): + input_dtype = np.float32 + diagonal, key = randn(input_shape, dtype=input_dtype, key=self.key) + scalar, key = randn((), dtype=input_dtype, key=key) + x, key = randn(input_shape, dtype=input_dtype, key=key) + + Id = linop.Identity(input_shape=input_shape) + Ids = linop.ScaledIdentity(scalar=scalar, input_shape=input_shape) + D = linop.Diagonal(diagonal=diagonal) + + IdD = operator(Id, D) + assert isinstance(IdD, linop.Diagonal) + snp.testing.assert_allclose(IdD @ x, operator(1.0, diagonal) * x, rtol=1e-6) + + DId = operator(D, Id) + assert isinstance(DId, linop.Diagonal) + snp.testing.assert_allclose(DId @ x, operator(diagonal, 1.0) * x, rtol=1e-6) + + IdIds = operator(Id, Ids) + assert isinstance(IdIds, linop.ScaledIdentity) + snp.testing.assert_allclose(IdIds @ x, operator(1.0, scalar) * x, rtol=1e-6) + + IdsId = operator(Ids, Id) + assert isinstance(IdsId, linop.ScaledIdentity) + snp.testing.assert_allclose(IdsId @ x, operator(scalar, 1.0) * x, rtol=1e-6) + + def test_scale(self): + input_shape = (5,) + input_dtype = np.float32 + scalar, key = randn((), dtype=input_dtype, key=self.key) + x, key = randn(input_shape, dtype=input_dtype, key=key) + Id = linop.Identity(input_shape=input_shape, input_dtype=input_dtype) + + sId = scalar * Id + assert isinstance(sId, linop.ScaledIdentity) + snp.testing.assert_allclose(sId @ x, scalar * x, rtol=1e-6) + + Ids = Id * scalar + assert isinstance(Ids, linop.ScaledIdentity) + snp.testing.assert_allclose(Ids @ x, scalar * x, rtol=1e-6) + + Idds = Id / scalar + assert isinstance(Idds, linop.ScaledIdentity) + snp.testing.assert_allclose(Idds @ x, x / scalar, rtol=1e-6)