Skip to content

Commit

Permalink
Complete the implementation of LLaMA tokenizer (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Sep 19, 2023
1 parent 0fc5de7 commit b038881
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions src/fairseq2/models/llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@final
class LLaMATokenizer(TextTokenizer):
- """Represents the tokenizer used by NLLB models."""
+ """Represents the tokenizer used by LLaMA models."""

model: SentencePieceModel

Expand All @@ -36,6 +36,10 @@ def __init__(self, pathname: PathLike) -> None:

vocabulary_info = vocabulary_from_sentencepiece(self.model)

# LLaMA tokenizer has no PAD symbol defined in its SentencePiece model
# and uses EOS instead.
vocabulary_info.pad_idx = vocabulary_info.eos_idx

super().__init__(vocabulary_info)

@finaloverride
Expand All @@ -51,18 +55,38 @@ def create_encoder(
"""Create a token encoder.
:param task:
- Not used in LLaMA, defaults to ``None``.
+ Not used.
:param lang:
- Not used in LLaMA, defaults to ``None``.
+ Not used.
:param mode:
- Not used in LLaMA, defaults to ``None``.
+ Must be 'default' or 'prompt'. If ``None``, defaults to 'default'.
:param device:
The device on which to construct tensors.
:param pin_memory:
If ``True``, uses pinned memory while constructing tensors.
"""
if task is not None:
raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

if lang is not None:
raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

if mode is None or mode == "default":
prefix_tokens = ["<s>"]
suffix_tokens = ["</s>"]
elif mode == "prompt":
prefix_tokens = ["<s>"]
# In prompt mode, we expect the generator to finish the sequence.
suffix_tokens = None
else:
raise ValueError(
f"`mode` must be 'default' or 'prompt', but is '{mode}' instead."
)

return SentencePieceEncoder(
self.model,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
device=device,
pin_memory=pin_memory,
)
Expand Down

0 comments on commit b038881

Please sign in to comment.