Skip to content

Commit

Permalink
Resolve #483 (#484)
Browse files Browse the repository at this point in the history
* Add a test for issue to be resolved

* Distinct handling for functions without any return value

* Resolve oversight identified in PR review
  • Loading branch information
bwohlberg authored Dec 14, 2023
1 parent e874051 commit 9b32636
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 26 deletions.
2 changes: 1 addition & 1 deletion scico/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
_wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction)

# wrap testing funcs
_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_func_over_blocks)
_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_void_func_over_blocks)

# clean up
del np, jnp, _wrappers
65 changes: 47 additions & 18 deletions scico/numpy/_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,34 +96,63 @@ def mapped(*args, **kwargs):
return mapped


def _num_blocks_in_args(*args, **kwargs):
"""Count the number of BlockArray arguments."""
first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
if first_ba_arg is None:
first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
if first_ba_kwarg is None:
num_blocks = 0
else:
num_blocks = len(first_ba_kwarg)
else:
num_blocks = len(first_ba_arg)
return num_blocks


def _block_args_kwargs(num_blocks, *args, **kwargs):
"""Construct nested args/kwargs for each BlockArray block."""
new_args = []
new_kwargs = []
for i in range(num_blocks):
new_args.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
new_kwargs.append(
{k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
)
return new_args, new_kwargs


def map_func_over_blocks(func):
"""Wrap a function so that it maps over all of its BlockArray
arguments.
"""

@wraps(func)
def mapped(*args, **kwargs):
num_blocks = _num_blocks_in_args(*args, **kwargs)
if num_blocks == 0:
return func(*args, **kwargs) # no BlockArray arguments, so no mapping
new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs)
# run the function num_blocks times, return results in a BlockArray
return BlockArray(func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks))

first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
if first_ba_arg is None:
first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
if first_ba_kwarg is None:
return func(*args, **kwargs) # no BlockArray arguments, so no mapping
num_blocks = len(first_ba_kwarg)
else:
num_blocks = len(first_ba_arg)
return mapped

# build a list of new args and kwargs, one for each block
new_args_list = []
new_kwargs_list = []
for i in range(num_blocks):
new_args_list.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
new_kwargs_list.append(
{k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
)

# run the function num_blocks times, return results in a BlockArray
return BlockArray(func(*new_args_list[i], **new_kwargs_list[i]) for i in range(num_blocks))
def map_void_func_over_blocks(func):
"""Wrap a function without a return value so that it maps over all
of its BlockArray arguments.
"""

@wraps(func)
def mapped(*args, **kwargs):
num_blocks = _num_blocks_in_args(*args, **kwargs)
if num_blocks == 0:
func(*args, **kwargs) # no BlockArray arguments, so no mapping
else:
new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs)
# run the function num_blocks times
[func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)]

return mapped

Expand Down
25 changes: 18 additions & 7 deletions scico/test/test_blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import scico.numpy as snp
from scico.numpy import BlockArray
from scico.numpy._wrapped_function_lists import testing_functions
from scico.numpy.testing import assert_array_equal
from scico.random import randn
from scico.util import rgetattr

math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex
comp_ops = [op.le, op.lt, op.ge, op.gt, op.eq]
Expand Down Expand Up @@ -86,7 +88,8 @@ def test_ba_ba_operator(test_operator_obj, operator):
snp.testing.assert_allclose(x, y)


# Testing the @ interface for blockarrays of same size, and a blockarray and flattened ndarray/devicearray
# Testing the @ interface for blockarrays of same size, and a blockarray and flattened
# ndarray/devicearray
def test_ba_ba_matmul(test_operator_obj):
a = test_operator_obj.a
b = test_operator_obj.d
Expand Down Expand Up @@ -135,20 +138,20 @@ def test_ndim(test_operator_obj):


def test_getitem(test_operator_obj):
# Make a length-4 blockarray
# make a length-4 blockarray
a0 = test_operator_obj.a0
a1 = test_operator_obj.a1
b0 = test_operator_obj.b0
b1 = test_operator_obj.b1
x = BlockArray([a0, a1, b0, b1])

# Positive indexing
# positive indexing
np.testing.assert_allclose(x[0], a0)
np.testing.assert_allclose(x[1], a1)
np.testing.assert_allclose(x[2], b0)
np.testing.assert_allclose(x[3], b1)

# Negative indexing
# negative indexing
np.testing.assert_allclose(x[-4], a0)
np.testing.assert_allclose(x[-3], a1)
np.testing.assert_allclose(x[-2], b0)
Expand Down Expand Up @@ -193,9 +196,7 @@ def test_ba_ba_dot(test_operator_obj, operator):
snp.testing.assert_allclose(x, y)


###############################################################################
# Reduction tests
###############################################################################
# reduction tests
reduction_funcs = [
snp.sum,
snp.linalg.norm,
Expand Down Expand Up @@ -315,6 +316,16 @@ def test_full_nodtype(self):
assert snp.all(x == fill_value)


# testing function tests
@pytest.mark.parametrize("func", testing_functions)
def test_test_func(func):
a = snp.array([1.0, 2.0])
b = snp.blockarray((a, a))
f = rgetattr(snp, func)
retval = f(b, b)
assert retval is None


# tests added for the BlockArray refactor
@pytest.fixture
def x():
Expand Down

0 comments on commit 9b32636

Please sign in to comment.