From b038881b4bc782b0250b41ae235bba566fc42383 Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Tue, 19 Sep 2023 15:09:55 -0400
Subject: [PATCH] Complete the implementation of LLaMA tokenizer (#57)

---
 src/fairseq2/models/llama/tokenizer.py | 32 ++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/src/fairseq2/models/llama/tokenizer.py b/src/fairseq2/models/llama/tokenizer.py
index 5a0b533cf..d0edf699a 100644
--- a/src/fairseq2/models/llama/tokenizer.py
+++ b/src/fairseq2/models/llama/tokenizer.py
@@ -23,7 +23,7 @@
 
 @final
 class LLaMATokenizer(TextTokenizer):
-    """Represents the tokenizer used by NLLB models."""
+    """Represents the tokenizer used by LLaMA models."""
 
     model: SentencePieceModel
 
@@ -36,6 +36,10 @@ def __init__(self, pathname: PathLike) -> None:
 
         vocabulary_info = vocabulary_from_sentencepiece(self.model)
 
+        # LLaMA tokenizer has no PAD symbol defined in its SentencePiece model
+        # and uses EOS instead.
+        vocabulary_info.pad_idx = vocabulary_info.eos_idx
+
         super().__init__(vocabulary_info)
 
     @finaloverride
@@ -51,18 +55,38 @@ def create_encoder(
         """Create a token encoder.
 
         :param task:
-            Not used in LLaMA, defaults to ``None``.
+            Not used.
         :param lang:
-            Not used in LLaMA, defaults to ``None``.
+            Not used.
         :param mode:
-            Not used in LLaMA, defaults to ``None``.
+            Must be 'default' or 'prompt'. If ``None``, defaults to 'default'.
         :param device:
             The device on which to construct tensors.
         :param pin_memory:
             If ``True``, uses pinned memory while constructing tensors.
         """
+        if task is not None:
+            raise ValueError(f"`task` must be `None`, but is '{task}' instead.")
+
+        if lang is not None:
+            raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")
+
+        if mode is None or mode == "default":
+            prefix_tokens = ["<s>"]
+            suffix_tokens = ["</s>"]
+        elif mode == "prompt":
+            prefix_tokens = ["<s>"]
+            # In prompt mode, we expect the generator to finish the sequence.
+            suffix_tokens = None
+        else:
+            raise ValueError(
+                f"`mode` must be 'default' or 'prompt', but is '{mode}' instead."
+            )
+
         return SentencePieceEncoder(
             self.model,
+            prefix_tokens=prefix_tokens,
+            suffix_tokens=suffix_tokens,
             device=device,
             pin_memory=pin_memory,
         )