diff --git a/scico/test/test_numpy_util.py b/scico/test/test_numpy_util.py index be1c9f5ab..faab01dc5 100644 --- a/scico/test/test_numpy_util.py +++ b/scico/test/test_numpy_util.py @@ -1,16 +1,10 @@ -import warnings - import numpy as np -import jax.numpy as jnp - import pytest import scico.numpy as snp -from scico.numpy import BlockArray from scico.numpy.util import ( complex_dtype, - ensure_on_device, indexed_shape, is_complex_dtype, is_nested, @@ -24,31 +18,6 @@ from scico.random import randn -def test_ensure_on_device(): - # Used to restore the warnings after the context is used - with warnings.catch_warnings(): - # Ignores warning raised by ensure_on_device - warnings.filterwarnings(action="ignore", category=UserWarning) - - NP = np.ones(2) - SNP = snp.ones(2) - BA = snp.blockarray([NP, SNP]) - NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA) - - assert isinstance(NP_, jnp.ndarray) - - assert isinstance(SNP_, jnp.ndarray) - assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer() - - assert isinstance(BA_, BlockArray) - assert isinstance(BA_[0], jnp.ndarray) - assert isinstance(BA_[1], jnp.ndarray) - assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer() - - NP_ = ensure_on_device(NP) - assert isinstance(NP_, jnp.ndarray) - - def test_no_nan_divide_array(): x, key = randn((4,), dtype=np.float32) y, key = randn(x.shape, dtype=np.float32, key=key)