diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py index 1413e6cea..d2a1e1af2 100644 --- a/tests/processors/test_base_processor.py +++ b/tests/processors/test_base_processor.py @@ -17,7 +17,8 @@ try: import mlx.core as mx - arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) + arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32) + arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) except ImportError: pass @@ -59,7 +60,12 @@ def test_from_torch(array_type, processor): torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) data = processor._from_torch(torch_tensor, type(arrays[array_type])) assert isinstance(data, type(arrays[array_type])) - assert np.allclose(data, arrays[array_type]) + if array_type == "mlx_bfloat16": + # For bfloat16, we expect the output to be float32 due to the conversion + assert data.dtype == mx.float32 + assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32)) + else: + assert np.allclose(data, arrays[array_type]) @pytest.mark.parametrize("array_type", arrays.keys())