diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index d5235ecc3..9762ee26a 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -787,11 +787,11 @@ def test_gemm_out_used(cpu_ops): @pytest.mark.parametrize("cpu_ops", CPU_OPS) -@settings(max_examples=MAX_EXAMPLES, deadline=None) -@given(X=strategies.arrays_BI()) -def test_flatten_unflatten_roundtrip(cpu_ops, X): +@settings(max_examples=MAX_EXAMPLES * 2, deadline=None) +@given(X=strategies.arrays_BI(dtype="i") | strategies.arrays_BI(dtype="f")) +def test_flatten_unflatten_roundtrip(cpu_ops: NumpyOps, X: numpy.ndarray): flat = cpu_ops.flatten([x for x in X]) - assert flat.ndim == 1 + assert flat.ndim == X.ndim - 1 unflat = cpu_ops.unflatten(flat, [len(x) for x in X]) assert_allclose(X, unflat) flat2 = cpu_ops.flatten([x for x in X], pad=1, dtype="f") diff --git a/thinc/tests/strategies.py b/thinc/tests/strategies.py index bc12975aa..67ca23b4f 100644 --- a/thinc/tests/strategies.py +++ b/thinc/tests/strategies.py @@ -3,6 +3,7 @@ from hypothesis.strategies import floats, integers, just, tuples from thinc.api import Linear, NumpyOps +from thinc.types import DTypes def get_ops(): @@ -34,8 +35,8 @@ def shapes(min_rows=1, max_rows=100, min_cols=1, max_cols=100): return tuples(lengths(lo=min_rows, hi=max_rows), lengths(lo=min_cols, hi=max_cols)) -def ndarrays_of_shape(shape, lo=-10.0, hi=10.0, dtype="float32", width=32): - if dtype.startswith("float"): +def ndarrays_of_shape(shape, lo=-10.0, hi=10.0, dtype: DTypes = "float32", width=32): + if dtype.startswith("f"): return arrays( dtype, shape=shape, elements=floats(min_value=lo, max_value=hi, width=width) ) @@ -49,18 +50,18 @@ def ndarrays(min_len=0, max_len=10, min_val=-10.0, max_val=10.0): ) -def arrays_BI(min_B=1, max_B=10, min_I=1, max_I=100): +def arrays_BI(min_B=1, max_B=10, min_I=1, max_I=100, dtype: DTypes = "float32"): shapes = tuples(lengths(lo=min_B, hi=max_B), lengths(lo=min_I, hi=max_I)) - return shapes.flatmap(ndarrays_of_shape) + return shapes.flatmap(lambda shape: ndarrays_of_shape(shape, dtype=dtype)) -def arrays_BOP(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5): +def arrays_BOP(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5, dtype: DTypes = "float32"): shapes = tuples( lengths(lo=min_B, hi=max_B), lengths(lo=min_O, hi=max_O), lengths(lo=min_P, hi=max_P), ) - return shapes.flatmap(ndarrays_of_shape) + return shapes.flatmap(lambda shape: ndarrays_of_shape(shape, dtype=dtype)) def arrays_BOP_BO(min_B=1, max_B=10, min_O=1, max_O=100, min_P=1, max_P=5):