Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: GGUF metadata tokenizer #389

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1545395
tests: Use `cfg(test)` attribute to avoid `dead_code` warnings
polarathene Jun 4, 2024
ad6ca10
tests: DRY codec test cases
polarathene Jun 4, 2024
f3ba6d9
chore: Add `TODO` note regarding test remote data dependency
polarathene Jun 4, 2024
18f1567
refactor: DRY metadata extraction
polarathene Jun 4, 2024
c9651fa
refactor: Extract `unigram` tokenizer out of match statement
polarathene Jun 5, 2024
5417221
chore: `rustfmt` adjustments + notes
polarathene Jun 5, 2024
fe24df7
refactor: GGUF Unigram Tokenizer Vocab construction
polarathene Jun 5, 2024
0c78b31
Merge branch 'master' into refactor/gguf-metadata-tokenizer
polarathene Jun 5, 2024
ea4fd54
Update gguf_tokenizer.rs
polarathene Jun 5, 2024
fa70ffc
chore: Rename `MetadataContext` => `ContentMetadata`
polarathene Jun 6, 2024
bbe4d00
chore: `verify_sanity_gguf()` => `verify_arch()`
polarathene Jun 6, 2024
4ee563a
chore: Expand GGUF `Value` enum types support
polarathene Jun 6, 2024
ec16212
refactor: GGUF metadata - `quantized_llama.rs`
polarathene Jun 6, 2024
4cf25e5
refactor: GGUF metadata - `quantized_phi2.rs`
polarathene Jun 6, 2024
c4dfe68
refactor: GGUF metadata - `quantized_phi3.rs`
polarathene Jun 6, 2024
bbea097
refactor: GGUF metadata - X-LoRA llama + phi3
polarathene Jun 6, 2024
86f538c
tests: Skip encoder test case for special tokens
polarathene Jun 6, 2024
8bdc736
Update mistralrs-core/src/pipeline/gguf_tokenizer.rs
polarathene Jun 6, 2024
b3705c3
refactor: Use convenience enums for Decoder and Normalizer inputs
polarathene Jun 7, 2024
130b1ac
chore: Add a tokenizer builder workaround
polarathene Jun 7, 2024
dba3024
chore: `MetadataContent` path_prefix to `&str`
polarathene Jun 7, 2024
4b8d775
tests: Skip Decoder with special tokens
polarathene Jun 7, 2024
67e972f
fix: Decoder tests
polarathene Jun 7, 2024
74b3319
tests: Replace web request with hard-coded string
polarathene Jun 7, 2024
fe48b9c
docs: Add maintenance reference comment
polarathene Jun 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mistralrs-core/src/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use candle_nn::{Linear, Module, VarBuilder};
use either::Either;

pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv, verify_sanity_gguf};
pub use crate::layers_utils::{flash_attn, repeat_kv};

use crate::{cublaslt::CUBLASLT_HANDLE, INHIBIT_GEMM_F16};

Expand Down
7 changes: 0 additions & 7 deletions mistralrs-core/src/layers_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result
unimplemented!("Compile with '--features flash-attn'")
}

pub fn verify_sanity_gguf(arch: &str, expected_arch: &str) -> Result<()> {
if arch != expected_arch {
candle_core::bail!("Expected `{expected_arch}` architecture, got `{arch}`.");
}
Ok(())
}

pub fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(x)
Expand Down
105 changes: 73 additions & 32 deletions mistralrs-core/src/models/quantized_llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, RotaryEmbedding};

use crate::device_map::DeviceMapper;
use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention,
};
use crate::layers::{repeat_kv, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::max_seq_len::get_gguf_max_seq_len;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

Expand Down Expand Up @@ -258,43 +256,86 @@ impl ModelConfig::FromGGML for ModelWeights {
}
}

// llama `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
pub n_expert: usize,
pub n_expert_used: usize,
pub head_count: usize,
pub head_count_kv: usize,
pub block_count: usize,
pub embedding_length: usize,
pub rope_dim: usize,
pub rms_norm_eps: f32,
pub max_seq_len: usize,
pub rope_freq_base: f32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("llama")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
];
c.has_required_keys(&required)?;

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
max_seq_len: c
.get_value::<u64>("context_length")
.ok()
.unwrap_or(MAX_SEQ_LEN as u64) as usize,
rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
// Parameter extraction from metadata.
let metadata = ContentMetadata {
path_prefix: "llama",
metadata: &ct.metadata,
};
verify_sanity_gguf(
md_get("general.architecture")?.to_string().unwrap(),
"llama",
)?;
let PropsGGUF {
n_expert,
n_expert_used,
head_count,
head_count_kv,
block_count,
embedding_length,
rope_dim,
rms_norm_eps,
max_seq_len,
rope_freq_base,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;

let max_seq_len =
get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize;

let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let head_dim = embedding_length / head_count;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
Expand Down
77 changes: 63 additions & 14 deletions mistralrs-core/src/models/quantized_phi2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::device_map::DeviceMapper;
use crate::layers::ScaledDotProductAttention;
use crate::layers::{repeat_kv, CausalMasker, QLinear};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::max_seq_len::get_gguf_max_seq_len;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

Expand Down Expand Up @@ -143,27 +143,76 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
Ok(ln)
}

// phi2 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
struct PropsGGUF {
head_count: usize,
head_count_kv: usize,
block_count: usize,
embedding_length: usize,
rope_dim: usize,
ln_eps: f64,
max_seq_len: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("phi2")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
"context_length",
];
c.has_required_keys(&required)?;

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
ln_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
max_seq_len: c
.get_value::<u64>("context_length")
.ok()
.unwrap_or(MAX_SEQ_LEN as u64) as usize,
polarathene marked this conversation as resolved.
Show resolved Hide resolved
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};

// Parameter extraction from metadata.
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
let max_seq_len =
get_gguf_max_seq_len(md_get("phi2.context_length"), MAX_SEQ_LEN as u64) as usize;
let metadata = ContentMetadata {
path_prefix: "phi2",
metadata: &ct.metadata,
};
let PropsGGUF {
head_count,
head_count_kv,
block_count,
embedding_length,
rope_dim,
ln_eps,
max_seq_len,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?;

Expand Down
81 changes: 65 additions & 16 deletions mistralrs-core/src/models/quantized_phi3.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::device_map::DeviceMapper;
use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention,
};
use crate::layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention};
use crate::pipeline::Cache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;
use candle_core::quantized::gguf_file;
Expand Down Expand Up @@ -160,28 +159,78 @@ fn precomput_freqs_cis(
Ok((cos, sin))
}

// phi3 `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
pub head_count: usize,
pub head_count_kv: usize,
pub block_count: usize,
pub embedding_length: usize,
pub i_size: usize,
pub rope_dim: usize,
pub rms_eps: f64,
pub context_window: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;

fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
c.verify_arch("phi3")?;

let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"feed_forward_length",
"rope.dimension_count",
"attention.layer_norm_rms_epsilon",
"context_length",
];
c.has_required_keys(&required)?;
EricLBuehler marked this conversation as resolved.
Show resolved Hide resolved

// NOTE: Values are not aligned with GGUFv3 types
// TODO: Normalize value types to spec
let props = Self {
head_count: c.get_value::<u32>("attention.head_count")? as usize,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: c.get_value::<u32>("embedding_length")? as usize,
i_size: c.get_value::<u32>("feed_forward_length")? as usize,
rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
rms_eps: c.get_value::<f32>("attention.layer_norm_rms_epsilon")? as f64,
context_window: c.get_value::<u32>("context_length")? as usize,
};

Ok(props)
}
}

impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
mapper: DeviceMapMetadata,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle_core::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
// Parameter extraction from metadata.
let metadata = ContentMetadata {
path_prefix: "phi3",
metadata: &ct.metadata,
};
verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?;
let PropsGGUF {
head_count,
head_count_kv,
block_count,
embedding_length,
i_size,
rope_dim,
rms_eps,
context_window,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

// Parameter extraction from metadata.
let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let context_window = md_get("phi3.context_length")?.to_u32()? as usize;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?;

let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
Expand Down
Loading
Loading