Skip to content

Commit

Permalink
Fix conversion from MLX to torch tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 9, 2024
1 parent 147b03e commit fb43f4f
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions outlines/processors/base_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
return torch.tensor(tensor_like)

elif is_mlx_array_type(type(tensor_like)):
# mlx -> torch -> mlx conversion docs:
# https://ml-explore.github.io/mlx/build/html/usage/numpy.html
return torch.from_dlpack(tensor_like)
import mlx.core as mx

# https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
return torch.from_dlpack(
np.array(tensor_like.astype(mx.float32), copy=False)
)

elif is_jax_array_type(type(tensor_like)):
import jax
Expand Down

0 comments on commit fb43f4f

Please sign in to comment.