diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index c14c5e977..68d670c41 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -1,6 +1,5 @@ import numpy as np -import jax from jax.scipy.linalg import block_diag import pytest @@ -225,24 +224,6 @@ def f(x): assert out.x.shape == x.shape np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - # check if minimize returns the object to the proper device - devices = jax.devices() - - # for default device - x0 = jax.device_put(snp.zeros_like(x), devices[0]) - out = solver.minimize(f, x0=x0, method=method) - assert out.x.device() == devices[0] - assert out.x.shape == x0.shape - np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - - # if more than one device is present - if len(devices) > 1: - x0 = jax.device_put(snp.zeros_like(x), devices[1]) - out = solver.minimize(f, x0=x0, method=method) - assert out.x.device() == devices[1] - assert out.x.shape == x0.shape - np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4) - def test_split_join_array(): x, key = random.randn((4, 4), dtype=np.complex64)