Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Oct 3, 2024
1 parent ccc9219 commit 5f6021b
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_condition.py → tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from hypothesis import strategies as st
from hypothesis.extra import numpy as npst

from anemoi.inference.condition import Condition
from anemoi.inference.state import State


@given(shape=npst.array_shapes(min_dims=2), data=st.data())
def test_from_numpy(shape, data):
# Test condition creation from numpy arrays
# Test state creation from numpy arrays
data_strategy = npst.arrays(
dtype=np.float32,
shape=shape,
Expand All @@ -24,11 +24,11 @@ def test_from_numpy(shape, data):
unique=True,
)
var_array = data.draw(var_strategy)
condition = Condition.from_numpy(data_array, var_array)
state = State.from_numpy(data_array, var_array)

assume(not np.isnan(data_array).any())
assume(np.isfinite(data_array).all())
assert np.allclose(condition.to_array(var_array), data_array)
assert np.allclose(state.to_array(var_array), data_array)

# Generate a permutation of indices
permutation = np.random.permutation(shape[0])
Expand All @@ -37,4 +37,4 @@ def test_from_numpy(shape, data):
new_data_array = data_array[permutation]
new_var_array = var_array[permutation]

assert np.allclose(condition.to_array(new_var_array), new_data_array)
assert np.allclose(state.to_array(new_var_array), new_data_array)

0 comments on commit 5f6021b

Please sign in to comment.