Skip to content

Commit

Permalink
Resolve type checking issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Dec 6, 2023
1 parent 4312239 commit d0ad4e4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions scico/linop/_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import operator
from functools import partial
from typing import List, Optional, Union
from typing import Optional, Sequence, Union

import scico.numpy as snp
from scico.numpy import Array, BlockArray
Expand Down Expand Up @@ -48,7 +48,7 @@ class VerticalStack(VStack, LinearOperator):

def __init__(
self,
ops: List[LinearOperator],
ops: Sequence[LinearOperator],
collapse_output: Optional[bool] = True,
jit: bool = True,
**kwargs,
Expand All @@ -66,7 +66,7 @@ def __init__(
super().__init__(ops=ops, collapse_output=collapse_output, jit=jit, **kwargs)

def _adj(self, y: Union[Array, BlockArray]) -> Array: # type: ignore
return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)])
return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) # type: ignore

@partial(_wrap_add_sub, op=operator.add)
def __add__(self, other):
Expand Down Expand Up @@ -126,7 +126,7 @@ class DiagonalStack(DStack, LinearOperator):

def __init__(
self,
ops: List[LinearOperator],
ops: Sequence[LinearOperator],
collapse_input: Optional[bool] = True,
collapse_output: Optional[bool] = True,
jit: bool = True,
Expand All @@ -151,7 +151,7 @@ def __init__(
)

def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type: ignore
result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y))
result = tuple(op.T @ y_n for op, y_n in zip(self.ops, y)) # type: ignore
if self.collapse_input:
return snp.stack(result)
return snp.blockarray(result)
8 changes: 4 additions & 4 deletions scico/operator/_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from typing import List, Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -78,7 +78,7 @@ class VerticalStack(Operator):

def __init__(
self,
ops: List[Operator],
ops: Sequence[Operator],
collapse_output: Optional[bool] = True,
jit: bool = True,
**kwargs,
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
)

@staticmethod
def check_if_stackable(ops: List[Operator]):
def check_if_stackable(ops: Sequence[Operator]):
"""Check that input ops are suitable for stack creation."""
if not isinstance(ops, (list, tuple)):
raise ValueError("Expected a list of Operator.")
Expand Down Expand Up @@ -227,7 +227,7 @@ class DiagonalStack(Operator):

def __init__(
self,
ops: List[Operator],
ops: Sequence[Operator],
collapse_input: Optional[bool] = True,
collapse_output: Optional[bool] = True,
jit: bool = True,
Expand Down

0 comments on commit d0ad4e4

Please sign in to comment.