Refactor: GGUF metadata tokenizer (#389)
* tests: Use `cfg(test)` attribute to avoid `dead_code` warnings

The proper way to opt out of dead-code warnings is to annotate the test module as being purely for testing.
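
A minimal sketch of the idea (the module contents here are illustrative, not the crate's actual tests):

```rust
// Illustrative only: gating the whole test module on `cfg(test)` means its
// helpers are compiled solely for `cargo test`, so they no longer trigger
// `dead_code` warnings in normal builds.
#[cfg(test)]
mod tests {
    // Only referenced by tests; without `#[cfg(test)]` on the module this
    // would be flagged as dead code in a regular build.
    fn fixture_ids() -> Vec<u32> {
        vec![1, 2, 3]
    }

    #[test]
    fn fixture_is_not_empty() {
        assert!(!fixture_ids().is_empty());
    }
}
```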

* tests: DRY codec test cases

Move the repetitive noise out into common helper functions used by the tests.

This also fixes a bug in the encode test case, where the upstream encoder/decoder calls flip the meaning of the `bool` for handling special tokens.
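
A rough sketch of the flip, assuming the usual `tokenizers` crate API (`encode(.., add_special_tokens)` vs `decode(.., skip_special_tokens)`); the helper name is hypothetical, not the one used in the PR:

```rust
use tokenizers::Tokenizer;

// Hypothetical helper: round-trips `text` through the HF tokenizer.
// Note the inversion between the two bools: `encode` takes
// `add_special_tokens`, while `decode` takes `skip_special_tokens`,
// so "keep special tokens visible" means passing `false` to `decode`.
fn codec_roundtrip(
    tokenizer: &Tokenizer,
    text: &str,
    add_special_tokens: bool,
) -> anyhow::Result<String> {
    let encoding = tokenizer
        .encode(text, add_special_tokens)
        .map_err(anyhow::Error::msg)?;
    tokenizer
        .decode(encoding.get_ids(), false)
        .map_err(anyhow::Error::msg)
}
```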

* chore: Add `TODO` note regarding test remote data dependency

* refactor: DRY metadata extraction

Retrieving metadata items from the hashmap of `Value` enums into primitive types, with error handling, is very verbose and noisy.

Use traits to abstract all of that away. This could also benefit usage in the models' `from_gguf()` methods.

Meanwhile, the unigram tokenizer now unifies the special-token handling at the end by keeping `unk` as a `u32` and only casting it to `usize` when actually needed.
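
A rough sketch of the helper's shape (the real `get_value` in this PR is generic over the target type via a conversion trait; the concrete getter below is only illustrative):

```rust
use std::collections::HashMap;

use anyhow::{anyhow, Result};
use candle_core::quantized::gguf_file::Value;

pub struct ContentMetadata<'a> {
    pub path_prefix: &'a str,
    pub metadata: &'a HashMap<String, Value>,
}

impl ContentMetadata<'_> {
    // One lookup of `"<prefix>.<key>"` replaces the repeated
    // match-on-Option + bail boilerplate at every call site.
    fn get(&self, key: &str) -> Result<&Value> {
        let path = format!("{}.{key}", self.path_prefix);
        self.metadata
            .get(&path)
            .ok_or_else(|| anyhow!("cannot find `{path}` in metadata"))
    }

    // Illustrative concrete getter; the PR's `get_value::<T>()` abstracts
    // the `Value` -> primitive conversion behind a trait instead.
    pub fn get_u32(&self, key: &str) -> Result<u32> {
        self.get(key)?.to_u32().map_err(|e| anyhow!("{e}"))
    }
}
```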

* refactor: Extract `unigram` tokenizer out of match statement

The special token strings are also created in the tokenizer now. A bit awkward, but it is unclear why only `unk` was an `Option`; presumably `bos` and `eos` may need similar treatment to `unk`?

* chore: `rustfmt` adjustments + notes

* refactor: GGUF Unigram Tokenizer Vocab construction

* Update gguf_tokenizer.rs

* chore: Rename `MetadataContext` => `ContentMetadata`

* chore: `verify_sanity_gguf()` => `verify_arch()`

This is a partial change; the method will be migrated fully in subsequent commits.

* chore: Expand GGUF `Value` enum types support

For the quantized models to leverage.

Additionally, this changes the helper methods over to `anyhow::Error`.
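
A sketch of the conversion pattern, assuming candle's `gguf_file::Value` variants (`U32`, `F32`, ...); the trait name is illustrative and only one impl is shown:

```rust
use candle_core::quantized::gguf_file::Value;

// Hypothetical conversion trait: each supported `Value` variant gets an impl,
// so a generic getter can hand back a plain primitive or an `anyhow::Error`.
trait TryValueInto<T>: Sized {
    fn try_value_into(self) -> anyhow::Result<T>;
}

impl TryValueInto<u32> for &Value {
    fn try_value_into(self) -> anyhow::Result<u32> {
        match self {
            Value::U32(v) => Ok(*v),
            _ => anyhow::bail!("metadata value is not a u32"),
        }
    }
}
```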

* refactor: GGUF metadata - `quantized_llama.rs`

* refactor: GGUF metadata - `quantized_phi2.rs`

* refactor: GGUF metadata - `quantized_phi3.rs`

* refactor: GGUF metadata - X-LoRA llama + phi3

- `get_gguf_max_seq_len` dropped, as it is not compatible with the current approach.
- Non-X-LoRA models export their `PropsGGUF`, since they share exactly the same code as their X-LoRA versions for this metadata.
- Switch from `eprintln!` to `warn!` for required props check.
- Additional cleanup.

* tests: Skip encoder test case for special tokens

When the encoder adds special tokens (`true`), the test also runs the decoder with `false` so that special tokens are not skipped during processing.

There is a mismatch between the HF and GGUF tokenizer outputs during decode, where the GGUF output is missing an initial `<s> `.

The advice is to skip this test case for now.

* Update mistralrs-core/src/pipeline/gguf_tokenizer.rs

* refactor: Use convenience enums for Decoder and Normalizer inputs

These two enum types are similar to their equivalent upstream `*Wrapper` types, except that instead of wrapping individual structs, they take a tuple of args and use a `TryFrom` impl to recreate the actual types (`new()` plus any error handling) before converting to the wrapped enum variant.

This shifts that noise and inconsistent API away, so that the tokenizer methods are easier to grok.
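
A minimal sketch of the pattern with a single variant (the PR's enums cover more decoders and normalizers; the upstream `Strip` constructor signature is an assumption from memory of the `tokenizers` crate):

```rust
use tokenizers::decoders::{strip::Strip, DecoderWrapper};

// Local convenience enum: variants carry plain args instead of wrapping the
// upstream structs directly.
enum Decoder {
    Strip(char, usize, usize),
    // ...the real refactor adds further variants (ByteFallback, Fuse, Replace, Sequence).
}

// `TryFrom` funnels the upstream `new()` calls (and any error handling) into
// one place, so the tokenizer methods only deal with plain values.
impl TryFrom<Decoder> for DecoderWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
        Ok(match variant {
            Decoder::Strip(content, left, right) => {
                DecoderWrapper::Strip(Strip::new(content, left, right))
            }
        })
    }
}
```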

* chore: Add a tokenizer builder workaround

Similar to the enum workaround. As the upstream builder is awkward to use, an alternative one is implemented to improve the DX.

The enum conversion to upstream wrapper types is handled in the builder now, simplifying usage in a tokenizer method.

* chore: `ContentMetadata` path_prefix to `&str`

* tests: Skip Decoder with special tokens

This test presently fails, due to the mismatch between the HF and the GGUF tokenizer used.

* fix: Decoder tests

This special character looks very similar to `_`, but it is not. This was a mistake I introduced when converting to the local enums approach.
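
For reference, the likely lookalike pair (an assumption based on SentencePiece conventions, since the commit does not name the character): the metaspace marker U+2581 renders almost identically to an ASCII underscore.

```rust
fn main() {
    let metaspace = '\u{2581}'; // '▁' — SentencePiece's word-boundary marker
    let underscore = '_'; // ASCII 0x5F
    assert_ne!(metaspace, underscore);
    println!(
        "{metaspace} (U+{:04X}) vs {underscore} (U+{:04X})",
        metaspace as u32, underscore as u32
    );
}
```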

* tests: Replace web request with hard-coded string

* docs: Add maintenance reference comment

Added context for this particular configuration.
polarathene authored Jun 7, 2024
1 parent 44e8a22 commit 8b2d092
Showing 11 changed files with 619 additions and 280 deletions.
2 changes: 1 addition & 1 deletion mistralrs-core/src/layers.rs
@@ -18,7 +18,7 @@ use candle_nn::{Linear, Module, VarBuilder};
use either::Either;

pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv, verify_sanity_gguf};
pub use crate::layers_utils::{flash_attn, repeat_kv};

use crate::{cublaslt::CUBLASLT_HANDLE, INHIBIT_GEMM_F16};

7 changes: 0 additions & 7 deletions mistralrs-core/src/layers_utils.rs
@@ -16,13 +16,6 @@ pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result
unimplemented!("Compile with '--features flash-attn'")
}

pub fn verify_sanity_gguf(arch: &str, expected_arch: &str) -> Result<()> {
if arch != expected_arch {
candle_core::bail!("Expected `{expected_arch}` architecture, got `{arch}`.");
}
Ok(())
}

pub fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(x)
105 changes: 73 additions & 32 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -6,11 +6,9 @@ use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};

use crate::device_map::DeviceMapper;
use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention,
};
use crate::layers::{repeat_kv, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::max_seq_len::get_gguf_max_seq_len;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

@@ -258,43 +256,86 @@ impl ModelConfig::FromGGML for ModelWeights {
}
}

// llama `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
pub n_expert: usize,
pub n_expert_used: usize,
pub head_count: usize,
pub head_count_kv: usize,
pub block_count: usize,
pub embedding_length: usize,
pub rope_dim: usize,
pub rms_norm_eps: f32,
pub max_seq_len: usize,
pub rope_freq_base: f32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("llama")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
];
c.has_required_keys(&required)?;

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
max_seq_len: c
.get_value::<u64>("context_length")
.ok()
.unwrap_or(MAX_SEQ_LEN as u64) as usize,
rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
// Parameter extraction from metadata.
let metadata = ContentMetadata {
path_prefix: "llama",
metadata: &ct.metadata,
};
verify_sanity_gguf(
md_get("general.architecture")?.to_string().unwrap(),
"llama",
)?;
let PropsGGUF {
n_expert,
n_expert_used,
head_count,
head_count_kv,
block_count,
embedding_length,
rope_dim,
rms_norm_eps,
max_seq_len,
rope_freq_base,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;

let max_seq_len =
get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize;

let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let head_dim = embedding_length / head_count;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
77 changes: 63 additions & 14 deletions mistralrs-core/src/models/quantized_phi2.rs
@@ -9,7 +9,7 @@ use crate::device_map::DeviceMapper;
use crate::layers::ScaledDotProductAttention;
use crate::layers::{repeat_kv, CausalMasker, QLinear};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::max_seq_len::get_gguf_max_seq_len;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

@@ -143,27 +143,76 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
Ok(ln)
}

// phi2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
struct PropsGGUF {
head_count: usize,
head_count_kv: usize,
block_count: usize,
embedding_length: usize,
rope_dim: usize,
ln_eps: f64,
max_seq_len: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("phi2")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
"context_length",
];
c.has_required_keys(&required)?;

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
max_seq_len: c
.get_value::<u64>("context_length")
.ok()
.unwrap_or(MAX_SEQ_LEN as u64) as usize,
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};

// Parameter extraction from metadata.
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
let max_seq_len =
get_gguf_max_seq_len(md_get("phi2.context_length"), MAX_SEQ_LEN as u64) as usize;
let metadata = ContentMetadata {
path_prefix: "phi2",
metadata: &ct.metadata,
};
let PropsGGUF {
head_count,
head_count_kv,
block_count,
embedding_length,
rope_dim,
ln_eps,
max_seq_len,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?;

81 changes: 65 additions & 16 deletions mistralrs-core/src/models/quantized_phi3.rs
@@ -1,10 +1,9 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::device_map::DeviceMapper;
use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention,
};
use crate::layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention};
use crate::pipeline::Cache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;
use candle_core::quantized::gguf_file;
@@ -160,28 +159,78 @@ fn precomput_freqs_cis(
Ok((cos, sin))
}

// phi3 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
pub head_count: usize,
pub head_count_kv: usize,
pub block_count: usize,
pub embedding_length: usize,
pub i_size: usize,
pub rope_dim: usize,
pub rms_eps: f64,
pub context_window: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("phi3")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"feed_forward_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
"context_length",
];
c.has_required_keys(&required)?;

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
i_size: c.get_value::<u32>("feed_forward_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
context_window: c.get_value::<u32>("context_length")? as usize,
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
// Parameter extraction from metadata.
let metadata = ContentMetadata {
path_prefix: "phi3",
metadata: &ct.metadata,
};
verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?;
let PropsGGUF {
head_count,
head_count_kv,
block_count,
embedding_length,
i_size,
rope_dim,
rms_eps,
context_window,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

// Parameter extraction from metadata.
let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let context_window = md_get("phi3.context_length")?.to_u32()? as usize;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?;

let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;