Skip to content

Commit

Permalink
Fix utf-8 decode errors in tiktoken wrapper (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Dec 8, 2023
1 parent 2017c02 commit 75cc1e1
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion llmfoundry/tokenizers/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self,
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
pad_token: Optional[str] = None,
errors: str = 'replace',
**kwargs: Any):
"""Constructor creates a tiktoken tokenizer to use as the underlying.
Expand All @@ -78,6 +79,9 @@ def __init__(self,
eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'.
bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'.
pad_token (Optional[str], optional): The pad token. Defaults to None.
errors (str, optional): Error-handling scheme to use when decoding bytes as UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
Defaults to `"replace"`.
"""
try:
import tiktoken
Expand Down Expand Up @@ -126,6 +130,7 @@ def pickle_Encoding(enc: Encoding):

self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.errors = errors

self.decoder: Dict[int, str] = {}
for i in range(self.encoding.n_vocab):
Expand Down Expand Up @@ -155,6 +160,7 @@ def pickle_Encoding(enc: Encoding):
eos_token=eos_token,
bos_token=bos_token,
pad_token=pad_token,
errors=errors,
**kwargs)

@property
Expand Down Expand Up @@ -252,7 +258,8 @@ def _convert_id_to_token(self, index: int) -> Optional[str]:
def convert_tokens_to_string(self, tokens: List[str]) -> str:
    """Converts a sequence of tokens (string) in a single string.

    The tokens are concatenated, every character is mapped back to its byte
    value through ``self.byte_decoder``, and the resulting bytes are decoded
    as UTF-8 using the ``self.errors`` error-handling scheme.

    Args:
        tokens (List[str]): The token strings to join and decode. Every
            character must be a key of ``self.byte_decoder``.

    Returns:
        str: The decoded text.

    Raises:
        KeyError: If a character in ``tokens`` is not in ``self.byte_decoder``.
        UnicodeDecodeError: If the bytes are not valid UTF-8 and
            ``self.errors`` is ``'strict'``.
    """
    text = ''.join(tokens)
    # Decode exactly once, honoring self.errors. A preceding strict decode
    # would raise on invalid byte sequences before the configured error
    # handling applies, and decoding an already-decoded string would corrupt
    # multi-byte characters.
    raw_bytes = bytearray(self.byte_decoder[c] for c in text)
    return raw_bytes.decode('utf-8', errors=self.errors)

def build_inputs_with_special_tokens(
Expand Down

0 comments on commit 75cc1e1

Please sign in to comment.