Skip to content

Commit

Permalink
Remove tests that are greatly complicated by the deprecation of jax d…
Browse files Browse the repository at this point in the history
…evice method
  • Loading branch information
bwohlberg committed Dec 13, 2023
1 parent d36610a commit 3f44eaa
Showing 1 changed file with 0 additions and 19 deletions.
19 changes: 0 additions & 19 deletions scico/test/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

import jax
from jax.scipy.linalg import block_diag

import pytest
Expand Down Expand Up @@ -225,24 +224,6 @@ def f(x):
assert out.x.shape == x.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# check if minimize returns the object to the proper device
devices = jax.devices()

# for default device
x0 = jax.device_put(snp.zeros_like(x), devices[0])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device() == devices[0]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)

# if more than one device is present
if len(devices) > 1:
x0 = jax.device_put(snp.zeros_like(x), devices[1])
out = solver.minimize(f, x0=x0, method=method)
assert out.x.device() == devices[1]
assert out.x.shape == x0.shape
np.testing.assert_allclose(out.x.ravel(), expected, rtol=5e-4)


def test_split_join_array():
x, key = random.randn((4, 4), dtype=np.complex64)
Expand Down

0 comments on commit 3f44eaa

Please sign in to comment.