Commit e8da97d

fix: minor update
ChieloNewctle committed Nov 6, 2023
1 parent c47ec6e commit e8da97d

Showing 5 changed files with 87 additions and 10 deletions.
13 changes: 6 additions & 7 deletions README.md
@@ -195,7 +195,9 @@ assert state.is_nil()
 ### `GreedyTokenizer`

 ```python
-vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
+from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
+
+vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
 trie, token_to_trie_node = build_trie_from_chars(vocab)

 trie_node_to_token = [-1] * trie.num_of_nodes()
@@ -208,12 +210,9 @@ tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
 def tokenize(s: str):
     return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]

-assert tokenize('歌曲折') == [(0, 2), (-1, 1)]
-assert tokenize('听歌曲') == [(5, 2), (-1, 1)]
-assert tokenize('听歌曲折') == [(5, 2), (6, 2)]
-assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
-assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
-assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
+assert tokenize('abcde') == [(9, 5)]
+assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
+assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
 ```

 ## License
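Note (illustrative, not part of this commit): each pair the README's `tokenize` returns is `(token_id, length)`, with `token_id == -1` marking a span that matches no vocabulary entry. A minimal sketch, reusing `vocab` and `tokenize` from the example above plus a hypothetical `decode` helper, that maps the output back onto the input string:

```python
# Hypothetical helper, for illustration only: recover the substring
# each (token_id, length) pair covers.
def decode(s: str, pairs):
    out, pos = [], 0
    for token_id, length in pairs:
        piece = s[pos:pos + length]
        # token_id == -1 means no vocabulary match for this span;
        # otherwise vocab[token_id] is exactly the matched substring.
        out.append(vocab[token_id] if token_id >= 0 else piece)
        pos += length
    return out

assert decode('abcdf', tokenize('abcdf')) == ['ab', 'cd', 'f']
```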
7 changes: 5 additions & 2 deletions general_sam/general_sam.pyi
@@ -39,7 +39,7 @@ class Trie:
     ) -> TrieNode: ...

 class GeneralSAMState:
-    def is_in_str(self) -> bool: ...
+    def is_in_chars(self) -> bool: ...
     def is_in_bytes(self) -> bool: ...
     def get_node_id(self) -> GeneralSAMNodeID: ...
     def is_nil(self) -> bool: ...
@@ -79,7 +79,7 @@ class GeneralSAM:
     def from_bytes(s: bytes) -> 'GeneralSAM': ...
     @staticmethod
     def from_trie(trie: Trie) -> 'GeneralSAM': ...
-    def is_in_str(self) -> bool: ...
+    def is_in_chars(self) -> bool: ...
     def is_in_bytes(self) -> bool: ...
     def num_of_nodes(self) -> int: ...
     def get_root_state(self) -> GeneralSAMState: ...
@@ -89,5 +89,8 @@ class GeneralSAM:
 class GreedyTokenizer:
     @staticmethod
     def from_sam_and_trie(sam: GeneralSAM, trie: Trie) -> 'GreedyTokenizer': ...
+    def get_sam(self) -> GeneralSAM: ...
+    def is_in_chars(self) -> bool: ...
+    def is_in_bytes(self) -> bool: ...
     def tokenize_str(self, s: str) -> Sequence[Tuple[TrieNodeID, int]]: ...
     def tokenize_bytes(self, s: bytes) -> Sequence[Tuple[TrieNodeID, int]]: ...
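Note (illustrative, not part of this commit): the three methods added to `GreedyTokenizer` mirror the introspection API that `GeneralSAM` already exposes. A minimal usage sketch, assuming a chars-based trie as in the README example:

```python
from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars

trie, _ = build_trie_from_chars(['ab', 'c'])
sam = GeneralSAM.from_trie(trie)
tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)

# The tokenizer reports the alphabet of its underlying automaton,
assert tokenizer.is_in_chars()
# and get_sam() hands that automaton back.
assert tokenizer.get_sam().is_in_chars()
```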
12 changes: 12 additions & 0 deletions src/tokenizer.rs
@@ -38,6 +38,18 @@ pub struct GreedyTokenizer(pub Arc<SharedGreedyTokenizer>);

 #[pymethods]
 impl GreedyTokenizer {
+    pub fn get_sam(&self) -> GeneralSAM {
+        GeneralSAM(self.0.borrow_sam().0.clone())
+    }
+
+    pub fn is_in_chars(&self) -> bool {
+        self.0.borrow_sam().is_in_chars()
+    }
+
+    pub fn is_in_bytes(&self) -> bool {
+        self.0.borrow_sam().is_in_bytes()
+    }
+
     #[staticmethod]
     pub fn from_sam_and_trie(sam: &GeneralSAM, trie: &Trie) -> PyResult<Self> {
         SharedGreedyTokenizer::from_sam_and_trie(sam, trie)
4 changes: 4 additions & 0 deletions tests/test_general_sam.py
@@ -3,6 +3,7 @@

 def test_bytes_abcbc():
     sam = GeneralSAM.from_bytes(b'abcbc')
+    assert sam.is_in_bytes()

     state = sam.get_root_state()
     state.feed_bytes(b'cbc')
@@ -15,6 +16,8 @@ def test_bytes_abcbc():

 def test_chars_abcbc():
     sam = GeneralSAM.from_chars('abcbc')
+    assert sam.is_in_chars()
+
     state = sam.get_root_state()

     state.feed_chars('b')
@@ -30,6 +33,7 @@ def test_chars_abcbc():
 def test_simple_sam_from_trie():
     trie, _ = build_trie_from_chars(['hello', 'Chielo'])
     sam = GeneralSAM.from_trie(trie)
+    assert trie.is_in_chars() and sam.is_in_chars()

     def fetch_state(s: str) -> GeneralSAMState:
         state = sam.get_root_state()
61 changes: 60 additions & 1 deletion tests/test_greedy_tokenizer.py
@@ -1,4 +1,29 @@
-from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
+from general_sam import (
+    GeneralSAM,
+    GreedyTokenizer,
+    build_trie_from_bytes,
+    build_trie_from_chars,
+)


+def test_english_chars_tokenize():
+    vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
+    trie, token_to_trie_node = build_trie_from_chars(vocab)
+
+    trie_node_to_token = [-1] * trie.num_of_nodes()
+    for i, j in enumerate(token_to_trie_node):
+        trie_node_to_token[j] = i
+
+    sam = GeneralSAM.from_trie(trie)
+    tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_chars()
+
+    def tokenize(s: str):
+        return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
+
+    assert tokenize('abcde') == [(9, 5)]
+    assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
+    assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
+
+
 def test_chinese_chars_tokenize():
@@ -11,6 +36,7 @@ def test_chinese_chars_tokenize():

     sam = GeneralSAM.from_trie(trie)
     tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_chars()

     def tokenize(s: str):
         return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
@@ -21,3 +47,36 @@ def tokenize(s: str):
     assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
     assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
     assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
+
+
+def test_chinese_bytes_tokenize():
+    vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
+    vocab = [i.encode() for i in vocab]
+    trie, token_to_trie_node = build_trie_from_bytes(vocab)
+
+    trie_node_to_token = [-1] * trie.num_of_nodes()
+    for i, j in enumerate(token_to_trie_node):
+        trie_node_to_token[j] = i
+
+    sam = GeneralSAM.from_trie(trie)
+    tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_bytes()
+
+    def tokenize_str(s: str):
+        return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_str(s)]
+
+    def tokenize_bytes(s: str):
+        return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_bytes(s.encode())]
+
+    def tokenize(s: str):
+        a = tokenize_str(s)
+        b = tokenize_bytes(s)
+        assert a == b
+        return a
+
+    assert tokenize('歌曲折') == [0, -1]
+    assert tokenize('听歌曲') == [5, -1]
+    assert tokenize('听歌曲折') == [5, 6]
+    assert tokenize('聆听歌曲折') == [1, -1]
+    assert tokenize('查看歌词歌曲') == [4, 0]
+    assert tokenize('一起播放歌曲并共享歌词') == [-1, 2, -1, 3]