From fb43f4fb7f711cc765e78671e991e1eecd51d38b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 9 Dec 2024 14:55:12 +0100 Subject: [PATCH] Fix conversion from MLX to torch tensor --- outlines/processors/base_logits_processor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index eec7de121..44b55af2e 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -107,9 +107,12 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: return torch.tensor(tensor_like) elif is_mlx_array_type(type(tensor_like)): - # mlx -> torch -> mlx conversion docs: - # https://ml-explore.github.io/mlx/build/html/usage/numpy.html - return torch.from_dlpack(tensor_like) + import mlx.core as mx + + # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch + return torch.from_dlpack( + np.array(tensor_like.astype(mx.float32), copy=False) + ) elif is_jax_array_type(type(tensor_like)): import jax