diff --git a/README.md b/README.md index ae8ab27..fe7673f 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,10 @@ See [`examples/rand-infer.rs`](examples/rand-infer.rs). echo '"def helloworl"' | cargo run --example rand-infer ``` +```sh +echo '"def helloworl"' | cargo run --example rand-infer +``` + ## TODOs - [ ] Python bindings diff --git a/examples/rand-infer.rs b/examples/rand-infer.rs index 224ac72..30efc1a 100644 --- a/examples/rand-infer.rs +++ b/examples/rand-infer.rs @@ -124,6 +124,15 @@ fn build_vocab>(tokenizer: T) -> Result>> { match parse_byte_repr(token) { Ok(byte) => token_bytes[id as usize].push(byte), Err(_) => { + if tokenizer + .get_added_vocabulary() + .get_added_tokens_decoder() + .contains_key(&id) + { + // ignore special tokens + continue; + } + let decoded = tokenizer .decode(&[dummy_token_id, id, dummy_token_id], false) .map_err(|e| eyre!(e))?; @@ -167,6 +176,24 @@ async fn main_body() -> Result<()> { .decode(tokenized.get_ids(), false) .map_err(|e| eyre!(e))?; + let offset = tokenized + .get_ids() + .iter() + .filter_map(|&id| { + tokenizer + .get_added_vocabulary() + .get_added_tokens_decoder() + .get(&id) + }) + .last() + .and_then(|special_token| { + println!("{special_token:?}"); + text.rfind(&special_token.content) + .map(|pos| pos + special_token.content.len()) + }) + .unwrap_or(0); + println!("search from pos {offset}\n"); + let Some((tree, mut req)) = SearchTree::new( automaton.clone(), |end_pos| async { @@ -178,7 +205,7 @@ async fn main_body() -> Result<()> { Ok::<_, tokenizers::Error>(res) }, text.as_str(), - 0, + offset, ) .await .map_err(|e| eyre!(e))?