Skip to content

Commit

Permalink
Reduce memory requirements of Identity linop (#501)
Browse files Browse the repository at this point in the history
* Clean up

* Work in progress: avoid creating a potentially large diagonal array for every Identity linop

* Trivial edit

* Typo fix

* Resolve some bugs

* Remove dependence of behaviour on order of operands

* Update year

* Remove attempt to impose commutative behaviour

* Clean up post #500 bug fix

* Clean up

* Scaling of an Identity returns a LinearOperator, not a Diagonal

* Add some tests

* Add ScaledIdentity linop

* Make binary operator behaviour commutative

* Add some tests

* Clean up

* Resolve mypy errors

* Address PR review comment

* Fix bug identified in PR review

* Add tests

* Wrapper simplification in progress

* Clean up _wrap_add_sub. Partially addresses #502

* Remove need for overriding __rmul__

* Remove now-unneccesary __rmul__ definitions
  • Loading branch information
bwohlberg authored Feb 15, 2024
1 parent 1a66887 commit 6e46ca2
Show file tree
Hide file tree
Showing 12 changed files with 471 additions and 136 deletions.
5 changes: 3 additions & 2 deletions scico/linop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -12,7 +12,7 @@
from ._circconv import CircularConvolve
from ._convolve import Convolve, ConvolveByX
from ._dft import DFT
from ._diag import Diagonal, Identity
from ._diag import Diagonal, Identity, ScaledIdentity
from ._diff import FiniteDifference, SingleAxisFiniteDifference
from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function
from ._linop import ComposedLinearOperator, LinearOperator
Expand All @@ -35,6 +35,7 @@
"Pad",
"Crop",
"Reshape",
"ScaledIdentity",
"Slice",
"Sum",
"Transpose",
Expand Down
12 changes: 3 additions & 9 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -8,8 +8,6 @@
"""Circular convolution linear operators."""

import math
import operator
from functools import partial
from typing import Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -205,7 +203,7 @@ def _adj(self, x: snp.Array) -> snp.Array: # type: ignore
H_adj_x = H_adj_x.real
return H_adj_x

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.ndims != other.ndims:
raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.")
Expand All @@ -218,7 +216,7 @@ def __add__(self, other):
h_is_dft=True,
)

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.ndims != other.ndims:
raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.")
Expand All @@ -241,10 +239,6 @@ def __mul__(self, scalar):
h_is_dft=True,
)

@_wrap_mul_div_scalar
def __rmul__(self, scalar):
return self * scalar

@_wrap_mul_div_scalar
def __truediv__(self, scalar):
return CircularConvolve(
Expand Down
35 changes: 5 additions & 30 deletions scico/linop/_convolve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -12,9 +12,6 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import operator
from functools import partial

import numpy as np

from jax.dtypes import result_type
Expand Down Expand Up @@ -85,7 +82,7 @@ def __init__(
def _eval(self, x: snp.Array) -> snp.Array:
return convolve(x, self.h, mode=self.mode)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand All @@ -102,7 +99,7 @@ def __add__(self, other):

raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand All @@ -129,17 +126,6 @@ def __mul__(self, scalar):
adj_fn=lambda x: snp.conj(scalar) * self.adj(x),
)

@_wrap_mul_div_scalar
def __rmul__(self, scalar):
return Convolve(
h=self.h * scalar,
input_shape=self.input_shape,
input_dtype=result_type(self.input_dtype, type(scalar)),
mode=self.mode,
output_shape=self.output_shape,
adj_fn=lambda x: snp.conj(scalar) * self.adj(x),
)

@_wrap_mul_div_scalar
def __truediv__(self, scalar):
return Convolve(
Expand Down Expand Up @@ -216,7 +202,7 @@ def __init__(
def _eval(self, h: snp.Array) -> snp.Array:
return convolve(self.x, h, mode=self.mode)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand All @@ -231,7 +217,7 @@ def __add__(self, other):
)
raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand Down Expand Up @@ -259,17 +245,6 @@ def __mul__(self, scalar):
adj_fn=lambda x: snp.conj(scalar) * self.adj(x),
)

@_wrap_mul_div_scalar
def __rmul__(self, scalar):
return ConvolveByX(
x=self.x * scalar,
input_shape=self.input_shape,
input_dtype=result_type(self.input_dtype, type(scalar)),
mode=self.mode,
output_shape=self.output_shape,
adj_fn=lambda x: snp.conj(scalar) * self.adj(x),
)

@_wrap_mul_div_scalar
def __truediv__(self, scalar):
return ConvolveByX(
Expand Down
Loading

0 comments on commit 6e46ca2

Please sign in to comment.