From 8805b40a99eeef2a7ad5a649e91d2361c6072ca1 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 13 Aug 2024 14:25:55 +0800 Subject: [PATCH] Support in-situ quantization (#77) * Support in-situ quantization * Typo fix * Cargo fmt --- README.md | 40 +++-- src/lib.rs | 92 ++++++++++- src/openai/models/gemma.rs | 73 +++++++-- src/openai/models/linear.rs | 263 ++++++++++++++++++++++++++++++- src/openai/models/llama.rs | 54 +++++-- src/openai/models/mistral.rs | 67 ++++++-- src/openai/models/mod.rs | 2 + src/openai/models/phi2.rs | 60 +++++-- src/openai/models/phi3.rs | 46 +++++- src/openai/models/qwen2.rs | 63 ++++++-- src/openai/models/stable_lm.rs | 69 ++++++-- src/openai/models/yi.rs | 67 ++++++-- src/openai/pipelines/pipeline.rs | 49 ++---- tests/tests.rs | 1 + 14 files changed, 810 insertions(+), 136 deletions(-) diff --git a/README.md b/README.md index 251bd97..25a9ccd 100644 --- a/README.md +++ b/README.md @@ -12,25 +12,26 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a - Streaming support in generation. - Efficient management of key-value cache with PagedAttention. - Continuous batching. +- In-situ quantization ## Develop Status Currently, candle-vllm supports chat serving for the following models. -| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16) -|--|--|--|--|--| -| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | -| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | -| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD| -| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)| -| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | -| #6 | **StableLM** |✅|99 tks/s (3B)|TBD| -| #7 | BigCode/StarCode |TBD|TBD|TBD | -| #8 | ChatGLM |TBD|TBD|TBD | -| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) | -| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD | -| #11 | Blip-large (Multimodal) |TBD|TBD|TBD | -| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD | +| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16) | Quantized (A100, Q8_0) | +|--|--|--|--|--|--| +| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | 65 tks/s (LLaMa3.1 8B) | +| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 78 tks/s (7B) | +| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-| +| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|116 tks/s (3.8B)| +| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 79 tks/s (6B)| +| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-| +| #7 | BigCode/StarCode |TBD|TBD|TBD |-| +| #8 | ChatGLM |TBD|TBD|TBD |-| +| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-| +| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-| +| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-| +| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-| ## Demo Chat with candle-vllm (61-65 tokens/s, LLaMa3.1 8B, bf16, on A100) @@ -187,6 +188,17 @@ async def benchmark(): asyncio.run(benchmark()) ``` +## In-situ quantization for consumer-grade GPUs + +Candle-vllm now supports in-situ quantization, allowing the transformation of default weights (F32/F16/BF16) into any GGML format during model loading. This feature helps conserve GPU memory, making it more efficient for consumer-grade GPUs (e.g., RTX 4090). For example, 8-bit quantization can reduce memory usage to less than 20GB for 8B models, while 4-bit quantization can bring it down to under 22GB for 13B models. 
To use this feature, simply supply the quant parameter when running candle-vllm. + +``` +cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --quant q8_0 +``` + +Options for `quant` parameters: ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"] + +**Please note** that batched processing still requires optimization when operating in quantization mode. ## Usage Help For general configuration help, run `cargo run -- --help`. diff --git a/src/lib.rs b/src/lib.rs index f53d074..0ef5d21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,7 @@ use candle::Result; use candle_core as candle; use clap::Subcommand; -use openai::pipelines::{ - pipeline::{DefaultLoader, SpecificConfig}, - ModelLoader, -}; +use openai::pipelines::{pipeline::DefaultLoader, ModelLoader}; #[derive(Debug, Subcommand)] pub enum ModelSelected { @@ -23,6 +20,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the llama3 model (default llama3.1-8b). @@ -39,6 +39,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the phi2 model (default 2.7b). @@ -55,6 +58,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the phi3 model (default 3.8b). @@ -77,6 +83,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the qwen model (default 1.8b). @@ -99,6 +108,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the gemma model (default 2b). @@ -115,6 +127,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the mistral model (default 7b). @@ -131,6 +146,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the Yi model (default 6b). @@ -147,6 +165,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the stable-lm model (default zephyr-3b). 
@@ -163,6 +184,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, } @@ -174,18 +198,21 @@ impl ToString for ModelSelected { temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "llama".to_string(), ModelSelected::Llama3 { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "llama3".to_string(), ModelSelected::Phi2 { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "phi2".to_string(), ModelSelected::Phi3 { repeat_last_n: _, @@ -194,6 +221,7 @@ impl ToString for ModelSelected { top_p: _, penalty: _, max_gen_tokens: _, + quant: _, } => "phi3".to_string(), ModelSelected::Qwen2 { repeat_last_n: _, @@ -202,35 +230,73 @@ impl ToString for ModelSelected { top_p: _, penalty: _, max_gen_tokens: _, + quant: _, } => "qwen2".to_string(), ModelSelected::Gemma { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "gemma".to_string(), ModelSelected::Mistral { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "mistral".to_string(), ModelSelected::Yi { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "yi".to_string(), ModelSelected::StableLM { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "stablelm".to_string(), } } } +#[derive(Debug, Clone)] +pub struct SpecificConfig { + repeat_last_n: Option, + temperature: Option, + top_k: Option, + top_p: Option, + penalty: Option, + max_gen_tokens: Option, + quant: Option, +} + +impl SpecificConfig { + pub fn new( + repeat_last_n: Option, + temperature: Option, + top_k: Option, + top_p: Option, + penalty: Option, + max_gen_tokens: Option, + quant: Option, + ) -> Self { + Self { + repeat_last_n, + temperature, + top_k, + top_p, + penalty, + max_gen_tokens, + quant, + } + } +} + pub fn get_model_loader( selected_model: ModelSelected, model_id: Option, @@ -241,6 +307,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -250,6 +317,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "llama".to_string(), )), @@ -264,6 +332,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -273,6 +342,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "llama3".to_string(), )), @@ -287,6 +357,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -296,6 +367,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "phi2".to_string(), )), @@ -312,6 +384,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -321,6 +394,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, ), "phi3".to_string(), )), @@ -337,6 +411,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -346,6 +421,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, ), "qwen2".to_string(), )), @@ -360,6 +436,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -369,6 +446,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "gemma".to_string(), )), @@ -383,6 +461,7 @@ pub fn 
get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -392,6 +471,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "mistral".to_string(), )), @@ -407,6 +487,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -416,6 +497,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "yi".to_string(), )), @@ -431,6 +513,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -440,6 +523,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "stablelm".to_string(), )), diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index af29b8a..0a8c6e3 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -1,12 +1,14 @@ use super::Config; -use crate::openai::models::linear::{linear_b, linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{ + linear_b_x as linear_b, linear_no_bias_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_core as candle; use candle_nn::Activation; use candle_nn::{RmsNorm, VarBuilder}; - use either::Either; use std::iter::zip; use std::sync::Arc; @@ -31,7 +33,12 @@ pub struct GemmaConfig { } impl GemmaConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { let hidden_act = match (self.hidden_act, self.hidden_activation) { (None, Some(act)) | (Some(act), None) => Some(act), (Some(_), Some(_)) => panic!("both hidden_act and hidden_activation are set"), @@ -61,6 +68,7 @@ impl GemmaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -135,9 +143,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let up_proj = linear( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -175,10 +198,34 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; let bias = cfg.attention_bias; - let q_proj = linear_b(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; - let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; - let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; - let o_proj = linear_b(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + let q_proj = linear_b( + hidden_sz, + num_heads * head_dim, + bias, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear_b( + hidden_sz, + num_kv_heads * 
head_dim, + bias, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear_b( + num_heads * head_dim, + hidden_sz, + bias, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -340,7 +387,11 @@ impl Gemma { layers.push(layer) } let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = Linear::new(embed_tokens.embeddings().clone(), None); + let lm_head = Linear::new( + embed_tokens.embeddings().clone(), + None, + &cfg.specific_config.quant, + ); Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/linear.rs b/src/openai/models/linear.rs index ca330dc..61f3001 100644 --- a/src/openai/models/linear.rs +++ b/src/openai/models/linear.rs @@ -18,8 +18,15 @@ //! # Ok(()) } //! ``` use crate::candle::Module; -use crate::candle::{Result, Tensor}; +use crate::candle::{ + quantized::{gguf_file, QMatMul, QTensor}, + DType, Device, Result, Tensor, +}; +use candle_core::quantized; use candle_nn::init; +use either::Either; +use std::sync::Arc; + #[derive(Clone, Debug)] pub struct Linear { weight: Tensor, @@ -126,3 +133,257 @@ pub fn linear_b( linear_no_bias(in_dim, out_dim, vb) } } + +#[derive(Debug, Clone)] +pub struct QLinear { + inner: QMatMul, + bias: Option, + dtype: DType, +} + +impl QLinear { + pub fn new( + ct: &gguf_file::Content, + r: &mut R, + name: &str, + device: &Device, + ) -> Result { + let w = ct.tensor(r, &format!("{name}.weight"), device)?; + let b = ct.tensor(r, &format!("{name}.bias"), device)?; + let inner = QMatMul::from_qtensor(w)?; + let bias = b.dequantize(device)?; + Ok(Self { + inner, + bias: Some(bias), + dtype: DType::F32, + }) + } + + pub fn from_linear(linear: Linear) -> Self { + Self { + inner: QMatMul::Tensor(linear.weight().clone()), + bias: linear.bias().cloned(), + dtype: linear.weight().dtype(), + } + } + + pub fn from_parts(w: Tensor, b: Option) -> Self { + let dtype = w.dtype(); + Self { + inner: QMatMul::Tensor(w), + bias: b, + dtype, + } + } + + pub fn from_qparts(w: QTensor, b: Option) -> Self { + if let Some(ref b) = b { + assert_eq!(b.dtype(), DType::F32); + } + Self { + inner: QMatMul::QTensor(Arc::new(w)), + bias: b, + dtype: DType::F32, + } + } + + pub fn from_qparts_x(w: QTensor, b: Option, dtype: DType) -> Self { + let bx = match b { + Some(b_) => { + if b_.dtype() != DType::F32 { + Some(b_.to_dtype(DType::F32).unwrap()) + } else { + Some(b_) + } + } + _ => None, + }; + + Self { + inner: QMatMul::QTensor(Arc::new(w)), + bias: bx, + dtype: dtype, + } + } + + pub fn from_linear_x(linear: Linear, quant: String) -> Self { + let weight = linear.weight(); + let dtype = weight.dtype(); + use quantized::GgmlDType; + + let ggml_dtype = match quant.as_str() { + "q4_0" => GgmlDType::Q4_0, + "q4_1" => GgmlDType::Q4_1, + "q5_0" => GgmlDType::Q5_0, + "q5_1" => GgmlDType::Q5_1, + "q8_0" => GgmlDType::Q8_0, + "q2k" => GgmlDType::Q2K, + "q3k" => GgmlDType::Q3K, + "q4k" => GgmlDType::Q4K, + "q5k" => GgmlDType::Q5K, + "q6k" => GgmlDType::Q6K, + _ => panic!("Unsupported GGML data type!"), + }; + let qtensor = QTensor::quantize(weight, ggml_dtype).unwrap(); + let qbias = match linear.bias() { + Some(b) => Some(b.clone()), + _ => None, + }; + + QLinear::from_qparts_x(qtensor, qbias, dtype) + } + + pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self { + Self { + inner, + bias: old.bias.clone(), + dtype: old.dtype, + } + } + + pub fn inner(&mut self) 
-> &mut QMatMul { + &mut self.inner + } + + pub fn inner_ref(&self) -> &QMatMul { + &self.inner + } + + pub fn is_quant(&self) -> bool { + matches!(self.inner, QMatMul::QTensor(_)) + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } + + pub fn bias_mut(&mut self) -> Option<&mut Tensor> { + self.bias.as_mut() + } +} + +impl Module for QLinear { + fn forward(&self, x: &Tensor) -> Result { + let xs = if self.is_quant() { + let x1 = match *x.dims() { + [bsize, seq_len, dim1, dim2] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim1, dim2))?.to_dtype(DType::F32)? + } + } + [bsize, seq_len, dim] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim))?.to_dtype(DType::F32)? + } + } + _ => x.to_dtype(DType::F32)?, + }; + x1 + } else { + x.clone() + }; + + let xs = match *x.dims() { + [bsize, seq_len, dim1, _] => { + if seq_len > 1 { + QMatMul::forward(&self.inner, &xs)? + } else { + QMatMul::forward(&self.inner, &xs)?.reshape((bsize, seq_len, dim1, ()))? + } + } + [bsize, seq_len, _] => { + if seq_len > 1 { + QMatMul::forward(&self.inner, &xs)? + } else { + QMatMul::forward(&self.inner, &xs)?.reshape((bsize, seq_len, ()))? + } + } + _ => QMatMul::forward(&self.inner, &xs)?, + }; + + if let Some(bias) = &self.bias { + xs.broadcast_add(bias)?.to_dtype(self.dtype) + } else { + xs.to_dtype(self.dtype) + } + } +} + +#[derive(Debug, Clone)] +pub struct LinearX(Either); + +impl Module for LinearX { + fn forward(&self, x: &Tensor) -> Result { + match &self.0 { + Either::Left(ln) => ln.forward(x), + Either::Right(ln) => ln.forward(x), + } + } +} +impl LinearX { + pub fn new(weight: Tensor, bias: Option, quant: &Option) -> Self { + let ln = Linear::new(weight, bias); + if let Some(quatized_type) = quant { + LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + ))) + } else { + LinearX(Either::Left(ln)) + } + } +} + +pub fn linear_x( + in_dim: usize, + out_dim: usize, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + let ln = linear(in_dim, out_dim, vb).unwrap(); + if let Some(quatized_type) = quant { + Ok(LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + )))) + } else { + Ok(LinearX(Either::Left(ln))) + } +} + +pub fn linear_no_bias_x( + in_dim: usize, + out_dim: usize, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + let init_ws = init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let ln = Linear::new(ws, None); + if let Some(quatized_type) = quant { + Ok(LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + )))) + } else { + Ok(LinearX(Either::Left(ln))) + } +} + +pub fn linear_b_x( + in_dim: usize, + out_dim: usize, + bias: bool, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + if bias { + linear_x(in_dim, out_dim, vb, quant) + } else { + linear_no_bias_x(in_dim, out_dim, vb, quant) + } +} diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 64f5ccb..168139f 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -1,15 +1,16 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_core as candle; 
use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; - pub const MAX_SEQ_LEN: usize = 4096; use crate::openai::models::TokenID; use std::iter::zip; + #[derive(Debug, Clone, serde::Deserialize)] pub struct LlamaConfig { pub hidden_size: usize, @@ -31,7 +32,12 @@ fn default_rope() -> f32 { } impl LlamaConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -56,6 +62,7 @@ impl LlamaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -184,10 +191,20 @@ impl CausalSelfAttention { let size_in = cfg.hidden_size; let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; - let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; - let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"), &cfg.specific_config.quant)?; + let k_proj = linear( + size_in, + size_kv, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear( + size_in, + size_kv, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"), &cfg.specific_config.quant)?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; Ok(Self { q_proj, @@ -232,9 +249,19 @@ impl Mlp { let span = tracing::span!(tracing::Level::TRACE, "mlp"); let h_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; - let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; - let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + let c_fc1 = linear( + h_size, + i_size, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"), &cfg.specific_config.quant)?; + let c_proj = linear( + i_size, + h_size, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { c_fc1, c_fc2, @@ -358,7 +385,12 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config, dtype: DType, device: &Device) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg, dtype, device).unwrap()) diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index 80cf71b..038e267 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear_no_bias, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; use 
candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; @@ -28,7 +29,12 @@ pub struct MistralConfig { } impl MistralConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -53,6 +59,7 @@ impl MistralConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -124,9 +131,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -163,10 +185,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_no_bias( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -322,7 +364,12 @@ impl Mistral { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index ac7d85d..1cb94e6 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -7,6 +7,7 @@ pub mod phi3; pub mod qwen2; pub mod stable_lm; pub mod yi; +use crate::SpecificConfig; use candle_core::DType; use either::Either; use serde::Deserialize; @@ -45,6 +46,7 @@ pub struct Config { pub kv_cache_dtype: DType, pub use_qkv_bias: Option, pub custom_stop_tokens: Option>, + pub specific_config: SpecificConfig, } impl Config { diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index fa9e2d5..4ee25a6 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias as linear, 
Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::{layer_norm, Embedding, LayerNorm}; @@ -32,7 +33,12 @@ pub struct Phi2Config { } impl Phi2Config { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -57,6 +63,7 @@ impl Phi2Config { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -115,8 +122,18 @@ struct MLP { impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?; - let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + let fc1 = linear( + cfg.hidden_size, + cfg.intermediate_size, + vb.pp("fc1"), + &cfg.specific_config.quant, + )?; + let fc2 = linear( + cfg.intermediate_size, + cfg.hidden_size, + vb.pp("fc2"), + &cfg.specific_config.quant, + )?; Ok(Self { fc1, fc2, @@ -153,10 +170,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; - let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?; + let q_proj = linear( + cfg.hidden_size, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear( + cfg.hidden_size, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear( + cfg.hidden_size, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let dense = linear( + num_heads * head_dim, + cfg.hidden_size, + vb.pp("dense"), + &cfg.specific_config.quant, + )?; // Alternative rope scalings are not supported. 
let rotary_emb = RotaryEmbedding::new(cfg, dtype, vb.device())?; let (q_layernorm, k_layernorm) = if cfg.qk_layer_rms_norm.unwrap() { @@ -324,7 +361,12 @@ impl Phi2 { let layer = DecoderLayer::new(cfg, dtype, vb_m.pp(layer_idx))?; layers.push(layer) } - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index c45e891..f750129 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -1,9 +1,10 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py use super::{Config, RopeScaling}; -use crate::openai::models::linear::{linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_core as candle; use candle_nn::VarBuilder; @@ -33,7 +34,12 @@ pub struct PhiConfig { } impl PhiConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -58,6 +64,7 @@ impl PhiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -235,8 +242,18 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim; - let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?; - let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?; + let qkv_proj = linear( + cfg.hidden_size, + op_size, + vb.pp("qkv_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear( + num_heads * head_dim, + cfg.hidden_size, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { qkv_proj, o_proj, @@ -340,8 +357,18 @@ impl Mlp { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?; - let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?; + let gate_up_proj = linear( + hidden_size, + 2 * i_size, + vb.pp("gate_up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear( + i_size, + hidden_size, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_up_proj, down_proj, @@ -430,7 +457,12 @@ impl Phi { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index 40fd59d..8b835b2 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -1,7 +1,10 @@ use super::Config; -use crate::openai::models::linear::{linear, linear_no_bias, Linear}; +use crate::openai::models::linear::{ + 
linear_no_bias_x as linear_no_bias, linear_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_core as candle; use candle_nn::VarBuilder; @@ -31,7 +34,12 @@ pub struct QwenConfig { } impl QwenConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -56,6 +64,7 @@ impl QwenConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -125,9 +134,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -164,10 +188,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -330,6 +374,7 @@ impl Qwen2 { } else { vb.pp("lm_head") }, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index b2acc54..9388c28 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -1,7 +1,10 @@ use super::Config; -use crate::openai::models::linear::{linear, linear_no_bias, Linear}; +use crate::openai::models::linear::{ + linear_no_bias_x as linear_no_bias, linear_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm, VarBuilder}; use either::Either; @@ -31,7 +34,12 @@ pub struct StableLMConfig { } impl StableLMConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn 
into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -59,6 +67,7 @@ impl StableLMConfig { kv_cache_dtype, use_qkv_bias: Some(self.use_qkv_bias.unwrap_or(false)), custom_stop_tokens: None, + specific_config: scfg.clone(), } } } @@ -125,9 +134,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -173,10 +197,30 @@ impl Attention { linear_no_bias }; - let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_layer( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear_layer( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear_layer( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -333,7 +377,12 @@ impl StableLM { layers.push(layer) } let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 20de25a..3ff4ecd 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear_no_bias, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; @@ -28,7 +29,12 @@ pub struct YiConfig { } impl YiConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -53,6 +59,7 @@ impl YiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: 
Some(vec!["<|im_end|>".to_string()]), + specific_config: scfg.clone(), } } } @@ -123,9 +130,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specific_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specific_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -162,10 +184,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_no_bias( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specific_config.quant, + )?; + let k_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specific_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -319,7 +361,12 @@ impl Yi { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specific_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs index 38ab155..2bf5340 100644 --- a/src/openai/pipelines/pipeline.rs +++ b/src/openai/pipelines/pipeline.rs @@ -26,7 +26,7 @@ use crate::{ PipelineConfig, }, paged_attention::input_metadata::InputMetadata, - try_api, + try_api, SpecificConfig, }; use candle_core::{DType, Device, IndexOp, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -42,37 +42,6 @@ const EOS_TOKEN: &str = ""; const SAMPLING_SEED: u64 = 299792458; const MIN_GEN_TOKENS: usize = 128; const MAX_GEN_TOKENS: usize = 4096; - -#[derive(Debug, Clone)] -pub struct SpecificConfig { - repeat_last_n: Option, - temperature: Option, - top_k: Option, - top_p: Option, - penalty: Option, - max_gen_tokens: Option, -} - -impl SpecificConfig { - pub fn new( - repeat_last_n: Option, - temperature: Option, - top_k: Option, - top_p: Option, - penalty: Option, - max_gen_tokens: Option, - ) -> Self { - Self { - repeat_last_n, - temperature, - top_k, - top_p, - penalty, - max_gen_tokens, - } - } -} - enum LLMModel { LLAMA(Llama), Phi2(Phi2), @@ -176,50 +145,50 @@ impl ModelLoader for DefaultLoader { let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - 
config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "phi2" => { let config: Phi2Config = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); //Phi2 use F32 type for kvcache - config.into_config(false, DType::F32) + config.into_config(false, DType::F32, &specific_args) } "phi3" => { let config: PhiConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( paths.get_config_filename() )),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "qwen2" => { let config: QwenConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "gemma" => { let config: GemmaConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "mistral" => { let config: MistralConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "yi" => { let config: YiConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( paths.get_config_filename() )),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "stablelm" => { let config: StableLMConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } _ => panic!("Model not supported!"), }; diff --git a/tests/tests.rs b/tests/tests.rs index 7a65c69..ccd9842 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -24,6 +24,7 @@ async fn test_llama() -> Result<(), APIError> { penalty: Some(1.1), temperature: None, max_gen_tokens: Some(512), + quant: None, }, Some("meta-llama/Llama-2-7b-chat-hf".to_string()), );
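
A minimal stand-alone sketch of the mechanism that the new `QLinear::from_linear_x` path in `src/openai/models/linear.rs` applies to each layer: a full-precision weight is converted into a GGML `QTensor` and then driven through `QMatMul`, using the same candle quantized API the patch relies on (`QTensor::quantize`, `QMatMul::from_qtensor`). The tensor shapes and the `q8_0` choice below are illustrative assumptions, not values taken from the patch.

```
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Full-precision "linear" weight, shape (out_features, in_features).
    // The inner dimension must be a multiple of the GGML block size (32 for q8_0).
    let weight = Tensor::randn(0f32, 1.0, (64, 128), &device)?;

    // In-situ quantization: F32 weight -> GGML q8_0 tensor, analogous to what
    // happens per layer when candle-vllm is launched with `--quant q8_0`.
    let qweight = QTensor::quantize(&weight, GgmlDType::Q8_0)?;
    let qmatmul = QMatMul::from_qtensor(qweight)?;

    // QMatMul computes x @ weight^T against the quantized weight.
    let x = Tensor::randn(0f32, 1.0, (4, 128), &device)?;
    let y = qmatmul.forward(&x)?;
    assert_eq!(y.dims(), &[4, 64]);
    Ok(())
}
```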