From f77c000884f4a4048d0a0d608ec08a1864c95d86 Mon Sep 17 00:00:00 2001 From: Aakash Thatte Date: Mon, 14 Oct 2024 21:20:55 +0530 Subject: [PATCH] remove cuda and direct use jax dlpack --- outlines/processors/base_logits_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index eee594e6f..eec7de121 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -115,7 +115,7 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: import jax torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like)) - return torch_tensor.cuda() + return torch_tensor else: raise TypeError( @@ -148,7 +148,7 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array: elif is_jax_array_type(target_type): import jax - return jax.numpy.from_dlpack(tensor) + return jax.dlpack.from_dlpack(tensor) else: raise TypeError(