From e8da97d10d3f94453aeff10385d0d717245ffacb Mon Sep 17 00:00:00 2001
From: Chielo Newctle
Date: Mon, 6 Nov 2023 15:41:41 +0800
Subject: [PATCH] fix: minor update

---
 README.md                      | 13 ++++----
 general_sam/general_sam.pyi    |  7 ++--
 src/tokenizer.rs               | 12 +++++++
 tests/test_general_sam.py      |  4 +++
 tests/test_greedy_tokenizer.py | 61 +++++++++++++++++++++++++++++++++-
 5 files changed, 87 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 7be596e..55c7ac5 100644
--- a/README.md
+++ b/README.md
@@ -195,7 +195,9 @@ assert state.is_nil()
 ### `GreedyTokenizer`
 
 ```python
-vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
+from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
+
+vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
 trie, token_to_trie_node = build_trie_from_chars(vocab)
 
 trie_node_to_token = [-1] * trie.num_of_nodes()
@@ -208,12 +210,9 @@ tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
 def tokenize(s: str):
     return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
 
-assert tokenize('歌曲折') == [(0, 2), (-1, 1)]
-assert tokenize('听歌曲') == [(5, 2), (-1, 1)]
-assert tokenize('听歌曲折') == [(5, 2), (6, 2)]
-assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
-assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
-assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
+assert tokenize('abcde') == [(9, 5)]
+assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
+assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
 ```
 
 ## License
diff --git a/general_sam/general_sam.pyi b/general_sam/general_sam.pyi
index 4237009..aa3218b 100644
--- a/general_sam/general_sam.pyi
+++ b/general_sam/general_sam.pyi
@@ -39,7 +39,7 @@ class Trie:
     ) -> TrieNode: ...
 
 class GeneralSAMState:
-    def is_in_str(self) -> bool: ...
+    def is_in_chars(self) -> bool: ...
     def is_in_bytes(self) -> bool: ...
     def get_node_id(self) -> GeneralSAMNodeID: ...
     def is_nil(self) -> bool: ...
@@ -79,7 +79,7 @@ class GeneralSAM:
     def from_bytes(s: bytes) -> 'GeneralSAM': ...
     @staticmethod
     def from_trie(trie: Trie) -> 'GeneralSAM': ...
-    def is_in_str(self) -> bool: ...
+    def is_in_chars(self) -> bool: ...
     def is_in_bytes(self) -> bool: ...
     def num_of_nodes(self) -> int: ...
     def get_root_state(self) -> GeneralSAMState: ...
@@ -89,5 +89,8 @@ class GeneralSAM:
 class GreedyTokenizer:
     @staticmethod
     def from_sam_and_trie(sam: GeneralSAM, trie: Trie) -> 'GreedyTokenizer': ...
+    def get_sam(self) -> GeneralSAM: ...
+    def is_in_chars(self) -> bool: ...
+    def is_in_bytes(self) -> bool: ...
     def tokenize_str(self, s: str) -> Sequence[Tuple[TrieNodeID, int]]: ...
     def tokenize_bytes(self, s: bytes) -> Sequence[Tuple[TrieNodeID, int]]: ...
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
index 5fcc58d..2722558 100644
--- a/src/tokenizer.rs
+++ b/src/tokenizer.rs
@@ -38,6 +38,18 @@ pub struct GreedyTokenizer(pub Arc);
 
 #[pymethods]
 impl GreedyTokenizer {
+    pub fn get_sam(&self) -> GeneralSAM {
+        GeneralSAM(self.0.borrow_sam().0.clone())
+    }
+
+    pub fn is_in_chars(&self) -> bool {
+        self.0.borrow_sam().is_in_chars()
+    }
+
+    pub fn is_in_bytes(&self) -> bool {
+        self.0.borrow_sam().is_in_bytes()
+    }
+
     #[staticmethod]
     pub fn from_sam_and_trie(sam: &GeneralSAM, trie: &Trie) -> PyResult {
         SharedGreedyTokenizer::from_sam_and_trie(sam, trie)
diff --git a/tests/test_general_sam.py b/tests/test_general_sam.py
index 25fae4b..01ee36e 100644
--- a/tests/test_general_sam.py
+++ b/tests/test_general_sam.py
@@ -3,6 +3,7 @@
 
 def test_bytes_abcbc():
     sam = GeneralSAM.from_bytes(b'abcbc')
+    assert sam.is_in_bytes()
 
     state = sam.get_root_state()
     state.feed_bytes(b'cbc')
@@ -15,6 +16,8 @@ def test_bytes_abcbc():
 
 def test_chars_abcbc():
     sam = GeneralSAM.from_chars('abcbc')
+    assert sam.is_in_chars()
+
     state = sam.get_root_state()
 
     state.feed_chars('b')
@@ -30,6 +33,7 @@
 def test_simple_sam_from_trie():
     trie, _ = build_trie_from_chars(['hello', 'Chielo'])
     sam = GeneralSAM.from_trie(trie)
+    assert trie.is_in_chars() and sam.is_in_chars()
 
     def fetch_state(s: str) -> GeneralSAMState:
         state = sam.get_root_state()
diff --git a/tests/test_greedy_tokenizer.py b/tests/test_greedy_tokenizer.py
index a1ccb9f..f44b41f 100644
--- a/tests/test_greedy_tokenizer.py
+++ b/tests/test_greedy_tokenizer.py
@@ -1,4 +1,29 @@
-from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
+from general_sam import (
+    GeneralSAM,
+    GreedyTokenizer,
+    build_trie_from_bytes,
+    build_trie_from_chars,
+)
+
+
+def test_english_chars_tokenize():
+    vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
+    trie, token_to_trie_node = build_trie_from_chars(vocab)
+
+    trie_node_to_token = [-1] * trie.num_of_nodes()
+    for i, j in enumerate(token_to_trie_node):
+        trie_node_to_token[j] = i
+
+    sam = GeneralSAM.from_trie(trie)
+    tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_chars()
+
+    def tokenize(s: str):
+        return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
+
+    assert tokenize('abcde') == [(9, 5)]
+    assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
+    assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
 
 
 def test_chinese_chars_tokenize():
@@ -11,6 +36,7 @@
 
     sam = GeneralSAM.from_trie(trie)
     tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_chars()
 
     def tokenize(s: str):
         return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
@@ -21,3 +47,36 @@ def tokenize(s: str):
     assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
     assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
     assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
+
+
+def test_chinese_bytes_tokenize():
+    vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
+    vocab = [i.encode() for i in vocab]
+    trie, token_to_trie_node = build_trie_from_bytes(vocab)
+
+    trie_node_to_token = [-1] * trie.num_of_nodes()
+    for i, j in enumerate(token_to_trie_node):
+        trie_node_to_token[j] = i
+
+    sam = GeneralSAM.from_trie(trie)
+    tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
+    assert tokenizer.is_in_bytes()
+
+    def tokenize_str(s: str):
+        return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_str(s)]
+
+    def tokenize_bytes(s: str):
+        return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_bytes(s.encode())]
+
+    def tokenize(s: str):
+        a = tokenize_str(s)
+        b = tokenize_bytes(s)
+        assert a == b
+        return a
+
+    assert tokenize('歌曲折') == [0, -1]
+    assert tokenize('听歌曲') == [5, -1]
+    assert tokenize('听歌曲折') == [5, 6]
+    assert tokenize('聆听歌曲折') == [1, -1]
+    assert tokenize('查看歌词歌曲') == [4, 0]
+    assert tokenize('一起播放歌曲并共享歌词') == [-1, 2, -1, 3]