Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
neilmehta24 committed Dec 22, 2024
1 parent 40f6ad1 commit a1056a5
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/processors/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
try:
import mlx.core as mx

arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16)
arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16)
except ImportError:
pass

Expand Down Expand Up @@ -59,7 +60,12 @@ def test_from_torch(array_type, processor):
torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
data = processor._from_torch(torch_tensor, type(arrays[array_type]))
assert isinstance(data, type(arrays[array_type]))
assert np.allclose(data, arrays[array_type])
if array_type == "mlx_bfloat16":
# For bfloat16, we expect the output to be float32 due to the conversion
assert data.dtype == mx.float32
assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32))
else:
assert np.allclose(data, arrays[array_type])


@pytest.mark.parametrize("array_type", arrays.keys())
Expand Down

0 comments on commit a1056a5

Please sign in to comment.