Skip to content

Commit

Permalink
Always move HF tokenizer encodings to the target device (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfowers authored Jan 24, 2025
1 parent c1463b0 commit 30eb2eb
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/lemonade/tools/huggingface_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from turnkeyml.state import State
import turnkeyml.common.status as status
from turnkeyml.tools import Tool, FirstTool
from lemonade.tools.adapter import ModelAdapter
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
from lemonade.cache import Keys

# Command line interfaces for tools will use string inputs for data
Expand All @@ -32,6 +32,26 @@ def make_example_inputs(state: State) -> Dict:
return {"input_ids": inputs_ids}


class HuggingfaceTokenizerAdapter(TokenizerAdapter):
    """
    Wrap a Hugging Face tokenizer so that every encoding it produces is
    moved to a fixed target device before being returned.

    Decoding calls and the end-of-sequence token id are forwarded to the
    wrapped tokenizer unchanged.
    """

    def __init__(self, tokenizer: transformers.AutoTokenizer, device: str):
        """
        Store the tokenizer to wrap and the device encodings are moved to.

        Args:
            tokenizer: the Hugging Face tokenizer that does the actual work.
            device: device identifier passed to `.to()` on each encoding.
        """
        super().__init__()
        self.device = device
        self.tokenizer = tokenizer

    def __call__(self, prompt, **kwargs):
        """
        Tokenize `prompt` (forwarding any keyword arguments to the wrapped
        tokenizer) and return the result after moving it to `self.device`.
        """
        encoding = self.tokenizer(prompt, **kwargs)
        # The move happens here so that callers never have to remember it
        return encoding.to(self.device)

    def decode(self, response, **kwargs):
        """Forward to the wrapped tokenizer's `decode` and return its result."""
        decoded = self.tokenizer.decode(response, **kwargs)
        return decoded

    def batch_decode(self, tokens, **kwargs):
        """Forward to the wrapped tokenizer's `batch_decode` and return its result."""
        decoded_batch = self.tokenizer.batch_decode(tokens, **kwargs)
        return decoded_batch

    @property
    def eos_token_id(self):
        """The `eos_token_id` reported by the wrapped tokenizer."""
        wrapped = self.tokenizer
        return wrapped.eos_token_id


class HuggingfaceLoad(FirstTool):
"""
Load an LLM as a torch.nn.Module using the Hugging Face transformers
Expand Down Expand Up @@ -167,7 +187,7 @@ def run(

# Pass the model and inputs into state
state.model = model
state.tokenizer = tokenizer
state.tokenizer = HuggingfaceTokenizerAdapter(tokenizer, device)
state.dtype = dtype
state.checkpoint = checkpoint
state.device = device
Expand Down

0 comments on commit 30eb2eb

Please sign in to comment.