Skip to content

Commit

Permalink
Delete ensure_on_device function
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 31, 2023
1 parent 66bf302 commit 7b3ecfa
Showing 1 changed file with 0 additions and 31 deletions.
31 changes: 0 additions & 31 deletions scico/test/test_numpy_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import warnings

import numpy as np

import jax.numpy as jnp

import pytest

import scico.numpy as snp
from scico.numpy import BlockArray
from scico.numpy.util import (
complex_dtype,
ensure_on_device,
indexed_shape,
is_complex_dtype,
is_nested,
Expand All @@ -24,31 +18,6 @@
from scico.random import randn


def test_ensure_on_device():
# Used to restore the warnings after the context is used
with warnings.catch_warnings():
# Ignores warning raised by ensure_on_device
warnings.filterwarnings(action="ignore", category=UserWarning)

NP = np.ones(2)
SNP = snp.ones(2)
BA = snp.blockarray([NP, SNP])
NP_, SNP_, BA_ = ensure_on_device(NP, SNP, BA)

assert isinstance(NP_, jnp.ndarray)

assert isinstance(SNP_, jnp.ndarray)
assert SNP.unsafe_buffer_pointer() == SNP_.unsafe_buffer_pointer()

assert isinstance(BA_, BlockArray)
assert isinstance(BA_[0], jnp.ndarray)
assert isinstance(BA_[1], jnp.ndarray)
assert BA[1].unsafe_buffer_pointer() == BA_[1].unsafe_buffer_pointer()

NP_ = ensure_on_device(NP)
assert isinstance(NP_, jnp.ndarray)


def test_no_nan_divide_array():
x, key = randn((4,), dtype=np.float32)
y, key = randn(x.shape, dtype=np.float32, key=key)
Expand Down

0 comments on commit 7b3ecfa

Please sign in to comment.