Skip to content

Commit

Permalink
remove cuda and direct use jax dlpack
Browse files Browse the repository at this point in the history
  • Loading branch information
sky-2002 committed Oct 14, 2024
1 parent 384ebad commit f77c000
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions outlines/processors/base_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
import jax

torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
return torch_tensor.cuda()
return torch_tensor

else:
raise TypeError(
Expand Down Expand Up @@ -148,7 +148,7 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
elif is_jax_array_type(target_type):
import jax

return jax.numpy.from_dlpack(tensor)
return jax.dlpack.from_dlpack(tensor)

else:
raise TypeError(
Expand Down

0 comments on commit f77c000

Please sign in to comment.