Refactor: GGUF metadata tokenizer #389

Merged

Changes from 9 commits (of 25 total)

Commits:
1545395
tests: Use `cfg(test)` attribute to avoid `dead_code` warnings
polarathene Jun 4, 2024
ad6ca10
tests: DRY codec test cases
polarathene Jun 4, 2024
f3ba6d9
chore: Add `TODO` note regarding test remote data dependency
polarathene Jun 4, 2024
18f1567
refactor: DRY metadata extraction
polarathene Jun 4, 2024
c9651fa
refactor: Extract `unigram` tokenizer out of match statement
polarathene Jun 5, 2024
5417221
chore: `rustfmt` adjustments + notes
polarathene Jun 5, 2024
fe24df7
refactor: GGUF Unigram Tokenizer Vocab construction
polarathene Jun 5, 2024
0c78b31
Merge branch 'master' into refactor/gguf-metadata-tokenizer
polarathene Jun 5, 2024
ea4fd54
Update gguf_tokenizer.rs
polarathene Jun 5, 2024
fa70ffc
chore: Rename `MetadataContext` => `ContentMetadata`
polarathene Jun 6, 2024
bbe4d00
chore: `verify_sanity_gguf()` => `verify_arch()`
polarathene Jun 6, 2024
4ee563a
chore: Expand GGUF `Value` enum types support
polarathene Jun 6, 2024
ec16212
refactor: GGUF metadata - `quantized_llama.rs`
polarathene Jun 6, 2024
4cf25e5
refactor: GGUF metadata - `quantized_phi2.rs`
polarathene Jun 6, 2024
c4dfe68
refactor: GGUF metadata - `quantized_phi3.rs`
polarathene Jun 6, 2024
bbea097
refactor: GGUF metadata - X-LoRA llama + phi3
polarathene Jun 6, 2024
86f538c
tests: Skip encoder test case for special tokens
polarathene Jun 6, 2024
8bdc736
Update mistralrs-core/src/pipeline/gguf_tokenizer.rs
polarathene Jun 6, 2024
b3705c3
refactor: Use convenience enums for Decoder and Normalizer inputs
polarathene Jun 7, 2024
130b1ac
chore: Add a tokenizer builder workaround
polarathene Jun 7, 2024
dba3024
chore: `MetadataContent` path_prefix to `&str`
polarathene Jun 7, 2024
4b8d775
tests: Skip Decoder with special tokens
polarathene Jun 7, 2024
67e972f
fix: Decoder tests
polarathene Jun 7, 2024
74b3319
tests: Replace web request with hard-coded string
polarathene Jun 7, 2024
fe48b9c
docs: Add maintenance reference comment
polarathene Jun 7, 2024
278 changes: 140 additions & 138 deletions mistralrs-core/src/pipeline/gguf_tokenizer.rs
@@ -10,6 +10,7 @@ use tokenizers::{
};
use tracing::info;

use crate::utils::gguf_metadata::MetadataContext;
use crate::DEBUG;

pub struct ConversionResult {
@@ -19,115 +20,71 @@ pub struct ConversionResult {
pub unk: Option<String>,
}

struct PropsGGUF {
model: String,
tokens: Vec<String>,
added_tokens: Option<Vec<String>>,
scores: Option<Vec<f32>>,
merges: Option<Vec<String>>,
unk: Option<u32>,
eos: u32,
bos: u32,
}

// This approach is a workaround for candles GGUF `Value` enum type wrapper,
// a better upstream approach would be to have serialize/deserialize support.
impl TryFrom<MetadataContext<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: MetadataContext) -> Result<Self, Self::Error> {
let required = ["model", "tokens", "eos_token_id", "bos_token_id"];
c.has_required_keys(&required)?;

let tokenizer_ggml = PropsGGUF {
model: c.get_value("model")?,
tokens: c.get_value("tokens")?,
added_tokens: c.get_value("added_tokens").ok(),
scores: c.get_value("scores").ok(),
merges: c.get_value("merges").ok(),
unk: c.get_value("unknown_token_id").ok(),
eos: c.get_value("eos_token_id")?,
bos: c.get_value("bos_token_id")?,
};

Ok(tokenizer_ggml)
}
}
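
The workaround comment above hints at the preferred upstream fix. As a rough sketch (hypothetical: candle's `Value` does not currently implement serde support, and the field renames assume the `tokenizer.ggml.` key prefix is stripped first), derive-based deserialization could replace the manual `TryFrom`:

// Hypothetical sketch: if candle's GGUF metadata supported serde deserialization,
// `PropsGGUF` could be derived rather than hand-written.
#[derive(serde::Deserialize)]
struct PropsGGUF {
    model: String,
    tokens: Vec<String>,
    added_tokens: Option<Vec<String>>,
    scores: Option<Vec<f32>>,
    merges: Option<Vec<String>>,
    #[serde(rename = "unknown_token_id")]
    unk: Option<u32>,
    #[serde(rename = "eos_token_id")]
    eos: u32,
    #[serde(rename = "bos_token_id")]
    bos: u32,
}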

pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult> {
let model = content.metadata["tokenizer.ggml.model"]
.to_string()
.expect("GGUF tokenizer model is not a string.")
.clone();
let tokens = content.metadata["tokenizer.ggml.tokens"]
.to_vec()
.expect("GGUF tokenizer tokens is not a vec.")
.iter()
.map(|t| t.to_string().expect("GGUF token is not a string.").clone())
.collect::<Vec<_>>();
let added_tokens = content
.metadata
.get("tokenizer.ggml.added_tokens")
.map(|items| {
items
.to_vec()
.expect("GGUF tokenizer added_tokens is not a vec.")
.iter()
.map(|t| {
t.to_string()
.expect("GGUF added_token is not a string.")
.clone()
})
.collect::<Vec<_>>()
});
let scores = content.metadata.get("tokenizer.ggml.scores").map(|items| {
items
.to_vec()
.expect("GGUF tokenizer scores is not a vec.")
.iter()
.map(|t| t.to_f32().expect("GGUF score is not a f32."))
.collect::<Vec<_>>()
});
let merges = content.metadata.get("tokenizer.ggml.merges").map(|items| {
items
.to_vec()
.expect("GGUF tokenizer merges is not a vec.")
.iter()
.map(|t| t.to_string().expect("GGUF merges is not a string.").clone())
.collect::<Vec<_>>()
});

let unk = content
.metadata
.get("tokenizer.ggml.unknown_token_id")
.map(|t| t.to_u32().expect("GGUF unk token is not u32"));

let eos = content.metadata["tokenizer.ggml.eos_token_id"]
.to_u32()
.expect("GGUF unk token is not u32");

let bos = content.metadata["tokenizer.ggml.bos_token_id"]
.to_u32()
.expect("GGUF unk token is not u32");

let bos_str = tokens[bos as usize].clone();
let eos_str = tokens[eos as usize].clone();
let unk_str;

let (tokenizer, ty) = match model.as_str() {
"llama" | "replit" => {
// This is a `unigram` tokenizer
let scores = scores
.as_ref()
.expect("Expect `tokenizer.ggml.scores` for `llama` unigram tokeizer.");
let mut vocab = Vec::new();
for (token, score) in tokens.iter().zip(scores) {
vocab.push((token.clone(), *score as f64));
}
let metadata = MetadataContext {
path_prefix: "tokenizer.ggml".to_string(),
metadata: &content.metadata,
};
let props = PropsGGUF::try_from(metadata)?;

// Unigram (sentencepiece) default UNK is 0
let unk = unk.map(|x| x as usize).unwrap_or(0);
unk_str = tokens[unk].clone();

let unigram = Unigram::from(vocab, Some(unk), true).map_err(anyhow::Error::msg)?;
let mut tokenizer = Tokenizer::new(ModelWrapper::Unigram(unigram));
tokenizer.with_decoder(decoders::sequence::Sequence::new(vec![
DecoderWrapper::Replace(Replace::new("▁", " ").map_err(anyhow::Error::msg)?),
DecoderWrapper::ByteFallback(ByteFallback::new()),
DecoderWrapper::Fuse(Fuse::new()),
DecoderWrapper::Strip(Strip::new(' ', 1, 0)),
]));
tokenizer.with_normalizer(normalizers::Sequence::new(vec![
NormalizerWrapper::Prepend(Prepend::new("▁".to_string())),
NormalizerWrapper::Replace(Replace::new(" ", "▁").map_err(anyhow::Error::msg)?),
]));

tokenizer.add_special_tokens(&[AddedToken::from(tokens[bos as usize].clone(), true)]);
tokenizer.add_special_tokens(&[AddedToken::from(tokens[eos as usize].clone(), true)]);
tokenizer.add_special_tokens(&[AddedToken::from(tokens[unk].clone(), true)]);

(tokenizer, "unigram")
}
let (tokenizer, kind, special_tokens) = match props.model.as_str() {
"llama" | "replit" => unigram_tokenizer(&props)?,
other => {
anyhow::bail!("Tokenizer model `{other}` not supported.");
}
};

info!(
"GGUF tokenizer model is `{model}`, kind: `{}`, num tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
ty,
"GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
tokenizer.get_vocab_size(true),
added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
merges.as_ref().map(|x| x.len()).unwrap_or(0),
scores.as_ref().map(|x| x.len()).unwrap_or(0)
props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
model = props.model,
);
if DEBUG.load(Ordering::Relaxed) {
info!("Tokenizer: {tokenizer:?}");
}

let [bos_str, eos_str, unk_str] = special_tokens
.try_into()
.or_else(|_| anyhow::bail!("Tokenizer is missing required special tokens"))?;

Ok(ConversionResult {
tokenizer,
bos: Some(bos_str),
@@ -136,6 +93,54 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult>
})
}
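
For orientation, a hedged sketch of how a caller might drive this conversion; `Content::read` is candle's GGUF reader, while the file path and error handling are illustrative:

// Illustrative caller: read a GGUF file, then convert its embedded tokenizer.
let mut file = std::fs::File::open("model.gguf")?; // path is illustrative
let content = Content::read(&mut file)?;
let ConversionResult { tokenizer, bos, eos, unk } = convert_ggml_to_hf_tokenizer(&content)?;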

// TODO: Add support for additional tokenizer models: BPE, WordPiece, WordLevel
// https://docs.rs/tokenizers/latest/tokenizers/models/enum.ModelWrapper.html
#[derive(Debug)]
enum TokenizerKind {
Unigram,
}

fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind, Vec<String>)> {
let PropsGGUF { unk, eos, bos, .. } = *p;
// Unigram (SentencePiece) default UNK is 0
let unk = unk.unwrap_or(0);

let vocab: Vec<(String, f64)> = {
let Some(s) = p.scores.as_ref() else {
anyhow::bail!(
"`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
);
};
let scores = s.iter().cloned().map(|f_32| f_32 as f64);

p.tokens.iter().cloned().zip(scores).collect()
};

let unigram = Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?;
let mut tokenizer = Tokenizer::new(ModelWrapper::Unigram(unigram));
tokenizer.with_decoder(decoders::sequence::Sequence::new(vec![
DecoderWrapper::Replace(Replace::new("▁", " ").map_err(anyhow::Error::msg)?),
DecoderWrapper::ByteFallback(ByteFallback::new()),
DecoderWrapper::Fuse(Fuse::new()),
DecoderWrapper::Strip(Strip::new(' ', 1, 0)),
]));
tokenizer.with_normalizer(normalizers::Sequence::new(vec![
NormalizerWrapper::Prepend(Prepend::new("▁".to_string())),
NormalizerWrapper::Replace(Replace::new(" ", "▁").map_err(anyhow::Error::msg)?),
]));

let mut special_tokens = Vec::<String>::new();
for token_id in [bos, eos, unk] {
let token = p.tokens[token_id as usize].as_str();

special_tokens.push(token.to_owned());
tokenizer.add_special_tokens(&[AddedToken::from(token.to_owned(), true)]);
}

Ok((tokenizer, TokenizerKind::Unigram, special_tokens))
}
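
The normalizer pair (`Prepend`/`Replace`) and the decoder chain (`Replace`/`ByteFallback`/`Fuse`/`Strip`) follow SentencePiece conventions, so an encode followed by a decode should reproduce the input. A rough illustration (hypothetical snippet, not part of the PR):

let text = "hello world";
let encoded = tokenizer.encode(text, false).map_err(anyhow::Error::msg)?;
// The normalizer turns "hello world" into "▁hello▁world" before segmentation.
let decoded = tokenizer.decode(encoded.get_ids(), true).map_err(anyhow::Error::msg)?;
// The decoder maps "▁" back to " ", fuses the pieces, and strips the leading space.
assert_eq!(decoded, text);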

#[cfg(test)]
mod tests {
use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
@@ -154,7 +159,6 @@ mod tests {
Rwkv,
}

#[allow(dead_code)]
fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
match tokenizer {
TokenizerType::Llama => {
@@ -179,7 +183,6 @@
}
}

#[allow(dead_code)]
fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
match tokenizer {
TokenizerType::Llama => {
@@ -197,50 +200,56 @@
}
}

#[allow(dead_code)]
fn get_test_passage() -> String {
// TODO: Why is it necessary to depend on this for a multi-line test string?
Owner: Perhaps we can hardcode this? It should be easy.

let passage = reqwest::blocking::get("https://loripsum.net/api")
.expect("Failed to download sample text")
.bytes()
.expect("Failed to get bytes");

String::from_utf8(passage.to_vec()).expect("Failed to convert sample text to string.")
}

// The provided passage should encode and decode back into the same passage string:
fn codec_roundtrip(
tokenizer: &Tokenizer,
passage: &str,
add_special_tokens: bool,
) -> Result<String> {
let tokenized = tokenizer
.encode(passage, add_special_tokens)
.map_err(anyhow::Error::msg)?;

// NOTE: The special tokens bool param meaning differs between encode() / decode():
decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
}

fn decode(
tokenizer: &Tokenizer,
token_ids: &[u32],
skip_special_tokens: bool,
) -> Result<String> {
tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(anyhow::Error::msg)
}

#[test]
fn test_encode_llama() -> Result<()> {
let passage = get_test_passage();
let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

// Without special tokens
let hf_tokenized = hf_tokenizer
.encode(passage.as_str(), false)
.map_err(anyhow::Error::msg)?;
let gguf_tokenized = gguf_tokenizer
.encode(passage.as_str(), false)
.map_err(anyhow::Error::msg)?;
let hf_decoded = hf_tokenizer
.decode(hf_tokenized.get_ids(), false)
.map_err(anyhow::Error::msg)?;
let gguf_decoded = gguf_tokenizer
.decode(gguf_tokenized.get_ids(), false)
.map_err(anyhow::Error::msg)?;
// Without adding special tokens
let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
assert_eq!(hf_decoded, gguf_decoded);

// With special tokens
let hf_tokenized = hf_tokenizer
.encode(passage.as_str(), true)
.map_err(anyhow::Error::msg)?;
let gguf_tokenized = gguf_tokenizer
.encode(passage.as_str(), true)
.map_err(anyhow::Error::msg)?;
let hf_decoded = hf_tokenizer
.decode(hf_tokenized.get_ids(), true)
.map_err(anyhow::Error::msg)?;
let gguf_decoded = gguf_tokenizer
.decode(gguf_tokenized.get_ids(), true)
.map_err(anyhow::Error::msg)?;
// With special tokens added
let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
assert_eq!(hf_decoded, gguf_decoded);

Ok(())
}

@@ -257,22 +266,15 @@
tokens.shuffle(&mut thread_rng());

// Without skipping special tokens
let hf_decoded = hf_tokenizer
.decode(&tokens, false)
.map_err(anyhow::Error::msg)?;
let gguf_decoded = gguf_tokenizer
.decode(&tokens, false)
.map_err(anyhow::Error::msg)?;
let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
assert_eq!(hf_decoded, gguf_decoded);

// With skipping special tokens
let hf_decoded = hf_tokenizer
.decode(&tokens, true)
.map_err(anyhow::Error::msg)?;
let gguf_decoded = gguf_tokenizer
.decode(&tokens, true)
.map_err(anyhow::Error::msg)?;
let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
assert_eq!(hf_decoded, gguf_decoded);

Ok(())
}
}