Typecheck more with mypyc
LoganDark committed Jun 5, 2023
1 parent 44b52ba commit 320bd2b
Showing 1 changed file with 20 additions and 16 deletions.
tokenizer/rwkv_tokenizer.py (36 changes: 20 additions & 16 deletions)
@@ -217,34 +217,37 @@ def printTokens(self, tokens):
 # Tokenizer #4 (fast) https://github.com/LoganDark
 ########################################################################################################
 
-from typing import Generator
+from typing import Generator, Iterable
 from ast import literal_eval
 
 class FastTokenizer:
     __slots__ = ('tok2val', 'root')
 
-    def __init__(self, file_name):
+    tok2val: Dict[int, bytes]
+    root: Dict[int, Entry]
+
+    def __init__(self, file_name) -> None:
         self.tok2val = {}
         self.root = {}
 
         with open(file_name, 'rt', encoding = 'utf-8') as file:
             for line in file:
-                token, value = line.rstrip().split(' ', 1)
-                value, expected_len = value.rsplit(' ', 1)
-                value = literal_eval(value)
-                if isinstance(value, str): value = value.encode('utf-8')
-                token, value, expected_len = int(token), value, int(expected_len)
-                assert len(value) == expected_len
-                self.add_token(token, value)
+                token_str, value_repr = line.rstrip().split(' ', 1)
+                value_repr, len_str = value_repr.rsplit(' ', 1)
+                value_str: Union[bytes, str] = literal_eval(value_repr)
+                value = value_str if isinstance(value_str, bytes) else value_str.encode('utf-8')
+                assert len(value) == int(len_str)
+                self.add_token(int(token_str), value)
 
-    def add_token(self, token: int, value: bytes):
+    def add_token(self, token: int, value: bytes) -> None:
         self.tok2val[token] = value
         pos = self.root
         for byte in value[:-1]: pos = pos.setdefault(byte, (None, {}))[1]
         pos.setdefault(value[-1], (token, {}))
 
     def next_token(self, src: bytes) -> Optional[int]:
-        last_token, last = None, self.root
+        last_token: Optional[int] = None
+        last = self.root
         for i in range(0, len(src)):
             if current := last.get(src[i]):
                 if token := current[0]: last_token = token
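
The new class-level annotations refer to an Entry type that this hunk does not show. Judging from the (None, {}) nodes that add_token builds and the current[0] / current[1] accesses in next_token, it is presumably a recursive alias along these lines (a sketch, not the file's actual definition):

    # Assumed shape of 'Entry', which is defined elsewhere in rwkv_tokenizer.py:
    # each trie node pairs the token that ends at this node (if any) with its
    # children, keyed by the next byte value.
    from typing import Dict, Optional, Tuple

    Entry = Tuple[Optional[int], Dict[int, 'Entry']]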
@@ -255,7 +258,8 @@ def next_token(self, src: bytes) -> Optional[int]:
     def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
         start, stop = 0, len(src)
         while start < stop:
-            last_token, last = None, self.root
+            last_token: Optional[int] = None
+            last = self.root
 
             for i in range(start, stop):
                 if current := last.get(src[i]):
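
For intuition, this loop implements greedy longest-match tokenization: walk the trie byte by byte, remember the last node that completed a token, emit that token, then restart just past it. A standalone illustration with hypothetical token ids (1 = b'a', 2 = b'b', 3 = b'ab'), assuming ids are nonzero as the truthiness tests require:

    # Build a tiny trie of the same shape add_token produces.
    root = {}
    for token, value in ((1, b'a'), (2, b'b'), (3, b'ab')):
        pos = root
        for byte in value[:-1]: pos = pos.setdefault(byte, (None, {}))[1]
        pos.setdefault(value[-1], (token, {}))

    # Greedy longest-match over b'abb'.
    src, out, start = b'abb', [], 0
    while start < len(src):
        last_token, last = None, root
        for i in range(start, len(src)):
            if current := last.get(src[i]):
                if token := current[0]: last_token, start = token, i + 1
                last = current[1]
            else: break
        if last_token: out.append(last_token)
        else: break
    print(out)  # [3, 2]: b'ab' wins over b'a' + b'b', then b'b' matches alone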
@@ -268,13 +272,13 @@ def encode_bytes(self, src: bytes) -> Generator[int, None, None]:
             if last_token: yield last_token
             else: break
 
-    def decode_bytes(self, tokens: list[int]) -> bytes:
-        return b''.join(map(self.tok2val.get, tokens))
+    def decode_bytes(self, tokens: Iterable[int]) -> bytes:
+        return b''.join(map(self.tok2val.__getitem__, tokens))
 
     def encode(self, src: str) -> Generator[int, None, None]:
         return self.encode_bytes(src.encode('utf-8'))
 
-    def decode(self, tokens: list[int]) -> str:
+    def decode(self, tokens: Iterable[int]) -> str:
         return self.decode_bytes(tokens).decode('utf-8')
 
 ########################################################################################################
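
A hypothetical round trip with the updated signatures; the vocab filename is an assumption, not something this commit specifies:

    # Any file in the "<token_id> <python literal> <byte length>" format that
    # __init__ parses would work here.
    tokenizer = FastTokenizer('rwkv_vocab_v20230424.txt')
    tokens = list(tokenizer.encode('Hello world'))  # encode returns a generator
    text = tokenizer.decode(tokens)                 # decode now accepts any Iterable[int]
    assert text == 'Hello world'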