From a3404b396dd2afe727470cbcb64b246846948fc7 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 28 Apr 2024 21:26:03 -0400 Subject: [PATCH 01/11] Sliding window for phi3 --- mistralrs-core/src/models/phi3.rs | 24 +++++++++++++++++++++--- mistralrs-core/src/pipeline/loaders.rs | 2 ++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 3de891a6b..077e3b272 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -33,6 +33,7 @@ pub struct Config { pub rope_scaling: Option, String>>>, pub max_position_embeddings: usize, pub use_flash_attn: bool, + pub sliding_window: Option, pub original_max_position_embeddings: usize, } @@ -277,6 +278,7 @@ pub struct Model { pub cache: Cache, pub max_seq_len: usize, mapper: Box, + sliding_window: Option, } impl Model { @@ -333,6 +335,7 @@ impl Model { cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, + sliding_window: cfg.sliding_window, }) } @@ -341,9 +344,20 @@ impl Model { b_size: usize, tgt_len: usize, seqlen_offset: usize, + sliding_window: Option, ) -> Result { + // Sliding window mask + let sliding_window = sliding_window.unwrap_or(tgt_len + 1); let mask: Vec<_> = (0..tgt_len) - .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) .collect(); let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; let mask = if seqlen_offset > 0 { @@ -384,8 +398,12 @@ impl Model { let attention_mask = if seq_len <= 1 { None } else { - let mask = - self.prepare_decoder_attention_mask(b_size, seq_len, past_key_values_length)?; + let mask = self.prepare_decoder_attention_mask( + b_size, + seq_len, + past_key_values_length, + self.sliding_window, + )?; Some(mask) }; let mut xs = self.embed_tokens.forward(input_ids)?; diff --git a/mistralrs-core/src/pipeline/loaders.rs b/mistralrs-core/src/pipeline/loaders.rs index 8b3750153..cfcc20565 100644 --- a/mistralrs-core/src/pipeline/loaders.rs +++ b/mistralrs-core/src/pipeline/loaders.rs @@ -520,6 +520,7 @@ struct Phi3BasicConfig { rope_scaling: Option>, max_position_embeddings: usize, original_max_position_embeddings: usize, + sliding_window: Option, } impl Phi3BasicConfig { @@ -545,6 +546,7 @@ impl Phi3BasicConfig { }), original_max_position_embeddings: basic_config.original_max_position_embeddings, use_flash_attn, + sliding_window: basic_config.sliding_window, }) } } From 92cb885c507bf9279a8212ed9970c16a6d9aab13 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 28 Apr 2024 22:22:00 -0400 Subject: [PATCH 02/11] Slice and dice --- mistralrs-core/src/models/phi3.rs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 077e3b272..306f8461c 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -53,6 +53,7 @@ struct Attention { head_dim: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -72,6 +73,7 @@ impl Attention { num_kv_groups: num_heads / num_kv_heads, head_dim, use_flash_attn: cfg.use_flash_attn, + sliding_window: cfg.sliding_window, }) } @@ -115,12 +117,24 @@ impl Attention { let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?; - let (k, v) = match &*kv_cache { 
- None => (k, v), - Some((prev_k, prev_v)) => { - let k = Tensor::cat(&[prev_k, &k], 2)?; - let v = Tensor::cat(&[prev_v, &v], 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + let slicing_tokens = 1 - sliding_window; + prev_k = prev_k.narrow(2, slicing_tokens, prev_k.dim(2)?)?; + prev_v = prev_v.narrow(2, slicing_tokens, prev_k.dim(2)?)?; + if let Some(ref mut mask) = mask { + *mask = mask.narrow(1, slicing_tokens, mask.dim(1)?)? + } + } + } + let k = Tensor::cat(&[prev_k, k], 2)?; + let v = Tensor::cat(&[prev_v, v], 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -139,9 +153,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)? From c97a2898408545263908f64146aed87b9f94db97 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 28 Apr 2024 22:29:25 -0400 Subject: [PATCH 03/11] Slice and dice without underflow --- mistralrs-core/src/models/phi3.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 306f8461c..acd7db7ef 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -124,11 +124,20 @@ impl Attention { if let Some(sliding_window) = self.sliding_window { let kv_seq_len = prev_k.dim(2)?; if kv_seq_len > sliding_window { - let slicing_tokens = 1 - sliding_window; - prev_k = prev_k.narrow(2, slicing_tokens, prev_k.dim(2)?)?; - prev_v = prev_v.narrow(2, slicing_tokens, prev_k.dim(2)?)?; + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; if let Some(ref mut mask) = mask { - *mask = mask.narrow(1, slicing_tokens, mask.dim(1)?)? + let mask_len = mask.dim(1)?; + *mask = + mask.narrow(1, mask_len - (sliding_window - 1), sliding_window - 1)? } } } From 6402255f7c135781ec2a3e50507b00bd53799126 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 28 Apr 2024 22:36:38 -0400 Subject: [PATCH 04/11] Propery do the attention mask --- mistralrs-core/src/models/phi3.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index acd7db7ef..daf36b7c0 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -136,8 +136,15 @@ impl Attention { )?; if let Some(ref mut mask) = mask { let mask_len = mask.dim(1)?; - *mask = - mask.narrow(1, mask_len - (sliding_window - 1), sliding_window - 1)? 
+ *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?], + D::Minus1, + )?; } } } From 6cb2f71bef499e2d34ace3a2bf6b7273385faba3 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Sun, 28 Apr 2024 22:38:19 -0400 Subject: [PATCH 05/11] Propery do the attention mask --- mistralrs-core/src/models/phi3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index daf36b7c0..891e97084 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -142,7 +142,7 @@ impl Attention { sliding_window - 1, )?; *mask = Tensor::cat( - &[&*mask, &mask.narrow(1, mask_len - 1, 1)?], + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], D::Minus1, )?; } From 6344b92ebee18f1c071094121c8c603c4668ce6f Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:09:54 -0400 Subject: [PATCH 06/11] Keep rope in f32 during init long ctxts --- mistralrs-core/src/layers.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 7dc7a6369..3babd9439 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -167,13 +167,13 @@ impl PhiRotaryEmbedding { .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) .collect(); let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let sin = freqs.sin()?; - let cos = freqs.cos()?; + let sin = freqs.sin()?.to_dtype(dtype)?; + let cos = freqs.cos()?.to_dtype(dtype)?; Ok(Self { short_cos: cos, short_sin: sin, From 9c62240261781b49c36a71b715ab16b7d43babfa Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:11:23 -0400 Subject: [PATCH 07/11] Keep rope in f32 during init long ctxts --- mistralrs-core/src/layers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 3babd9439..43b2a0bbc 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -169,7 +169,7 @@ impl PhiRotaryEmbedding { let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? + .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; let sin = freqs.sin()?.to_dtype(dtype)?; From cfe480cf6b50dc593d37d9e5d587196a6c1227e0 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:17:09 -0400 Subject: [PATCH 08/11] Keep rope in f32 for 128k longrope --- mistralrs-core/src/layers.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 43b2a0bbc..c19f1908c 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -137,22 +137,20 @@ impl PhiRotaryEmbedding { let inv_freq_len = inv_freq_long.len(); let t = Tensor::arange(0u32, max_seq_len as u32, dev)? - .to_dtype(dtype)? + .to_dtype(DType::F32)? 
.reshape((max_seq_len, 1))?; // Calculate sin,cos for long - let inv_freq_long = - Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?; let freqs_long = t.matmul(&inv_freq_long)?; - let long_sin = freqs_long.sin()?.mul(scaling_factor)?; - let long_cos = freqs_long.cos()?.mul(scaling_factor)?; + let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?; + let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?; // Calculate sin,cos for short - let inv_freq_short = - Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?; let freqs_short = t.matmul(&inv_freq_short)?; - let short_sin = freqs_short.sin()?.mul(scaling_factor)?; - let short_cos = freqs_short.cos()?.mul(scaling_factor)?; + let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?; + let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?; Ok(Self { short_cos, From f56fcd0e9c10258eac4033a8ec8e75d5772c56f0 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:23:20 -0400 Subject: [PATCH 09/11] Implement for xlora phi3 --- mistralrs-core/src/xlora_models/phi3.rs | 69 +++++++++++++++++++++---- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 13080c8d1..2456f2770 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -31,6 +31,7 @@ struct Attention { head_dim: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -77,6 +78,7 @@ impl Attention { num_kv_groups: num_heads / num_kv_heads, head_dim, use_flash_attn: cfg.use_flash_attn, + sliding_window: cfg.sliding_window, }) } @@ -129,12 +131,40 @@ impl Attention { let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?; - let (k, v) = match &*kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = Tensor::cat(&[prev_k, &k], 2)?; - let v = Tensor::cat(&[prev_v, &v], 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + if let Some(ref mut mask) = mask { + let mask_len = mask.dim(1)?; + *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], + D::Minus1, + )?; + } + } + } + let k = Tensor::cat(&[prev_k, k], 2)?; + let v = Tensor::cat(&[prev_v, v], 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -153,9 +183,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? 
* scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)? @@ -365,6 +395,7 @@ pub struct Model { pub max_seq_len: usize, mapper: Box, xlora_classifier: Option, + sliding_window: Option, } impl Model { @@ -451,6 +482,7 @@ impl Model { xlora_classifier: xlora_config.map(|xlora_config| { XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap() }), + sliding_window: cfg.sliding_window, }) } @@ -459,9 +491,20 @@ impl Model { b_size: usize, tgt_len: usize, seqlen_offset: usize, + sliding_window: Option, ) -> Result { + // Sliding window mask + let sliding_window = sliding_window.unwrap_or(tgt_len + 1); let mask: Vec<_> = (0..tgt_len) - .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) .collect(); let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; let mask = if seqlen_offset > 0 { @@ -506,8 +549,12 @@ impl Model { let attention_mask = if seq_len <= 1 { None } else { - let mask = - self.prepare_decoder_attention_mask(b_size, seq_len, past_key_values_length)?; + let mask = self.prepare_decoder_attention_mask( + b_size, + seq_len, + past_key_values_length, + self.sliding_window, + )?; Some(mask) }; let mut xs = self.embed_tokens.forward(input_ids)?; From e765d873b9c168808b060b6f9343317bca69a594 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:26:56 -0400 Subject: [PATCH 10/11] Implement for mistral models --- mistralrs-core/src/models/mistral.rs | 46 ++++++++++++++++++---- mistralrs-core/src/xlora_models/mistral.rs | 46 ++++++++++++++++++---- 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 244f7f564..2f3a45779 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -85,6 +85,7 @@ struct Attention { hidden_size: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -110,6 +111,7 @@ impl Attention { hidden_size: hidden_sz, rotary_emb, use_flash_attn: cfg.use_flash_attn, + sliding_window: cfg.sliding_window, }) } @@ -157,12 +159,40 @@ impl Attention { .contiguous()?; } - let (k, v) = match &*kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?; - let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + if let Some(ref mut mask) = mask { + let mask_len = mask.dim(1)?; + *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], + D::Minus1, + )?; + } + } + } + let k = 
candle_nn::ops::kvconcat(&prev_k, &k, 2)?; + let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -181,9 +211,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)? diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index 222cc0c23..61bbf07fd 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -128,6 +128,7 @@ struct Attention { head_dim: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -195,6 +196,7 @@ impl Attention { head_dim, rotary_emb, use_flash_attn: cfg.use_flash_attn, + sliding_window: cfg.sliding_window, }) } @@ -261,12 +263,40 @@ impl Attention { .contiguous()?; } - let (k, v) = match &*kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?; - let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + if let Some(ref mut mask) = mask { + let mask_len = mask.dim(1)?; + *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], + D::Minus1, + )?; + } + } + } + let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?; + let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -285,9 +315,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)? 
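Patches 02-05, 09, and 10 repeat the same cache-truncation step inside each attention layer's forward pass (phi3, xlora phi3, and here the mistral models): once a sliding window is configured and the cached keys/values have grown past it, only the most recent sliding_window - 1 positions are kept before the new token's key/value are concatenated. The helper below is a minimal standalone sketch of that step, assuming the candle_core crate these files already build on; the function name truncate_kv_for_sliding_window is illustrative and does not appear in the patches.

use candle_core::{Result, Tensor};

/// Sketch of the sliding-window KV-cache truncation used in these patches.
/// `prev_k`/`prev_v` are cached keys/values shaped [b, n_kv_heads, kv_seq_len, head_dim];
/// only the most recent `sliding_window - 1` positions along dim 2 are kept, so the
/// cache holds at most `sliding_window` positions once the current token is appended.
fn truncate_kv_for_sliding_window(
    prev_k: Tensor,
    prev_v: Tensor,
    sliding_window: usize,
) -> Result<(Tensor, Tensor)> {
    let kv_seq_len = prev_k.dim(2)?;
    if kv_seq_len <= sliding_window {
        // Cache still fits inside the window; nothing to drop.
        return Ok((prev_k, prev_v));
    }
    let keep = sliding_window - 1;
    let start = kv_seq_len - keep;
    // Drop everything older than the window along the sequence dimension.
    let k = prev_k.narrow(2, start, keep)?;
    let v = prev_v.narrow(2, start, keep)?;
    Ok((k, v))
}

Keeping sliding_window - 1 cached positions means the attended span is exactly sliding_window tokens once the current token's key/value are appended, which is why the narrow arithmetic above matches the inline code in each Attention::forward.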
From d3439affe8c374bbe399ded14bf6785becc9c70b Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Mon, 29 Apr 2024 06:32:34 -0400 Subject: [PATCH 11/11] Implement for mixtral models --- mistralrs-core/src/models/mixtral.rs | 46 ++++++++++++++++++---- mistralrs-core/src/xlora_models/mixtral.rs | 46 ++++++++++++++++++---- 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index d5797c801..2a1887040 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -49,6 +49,7 @@ struct Attention { hidden_size: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -74,6 +75,7 @@ impl Attention { hidden_size: hidden_sz, rotary_emb, use_flash_attn: cfg.use_flash_attn, + sliding_window: Some(cfg.sliding_window), }) } @@ -121,12 +123,40 @@ impl Attention { .contiguous()?; } - let (k, v) = match &*kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?; - let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + if let Some(ref mut mask) = mask { + let mask_len = mask.dim(1)?; + *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], + D::Minus1, + )?; + } + } + } + let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?; + let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -145,9 +175,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)? 
diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 9b594df56..b6826938d 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -32,6 +32,7 @@ struct Attention { head_dim: usize, rotary_emb: Arc, use_flash_attn: bool, + sliding_window: Option, } impl Attention { @@ -99,6 +100,7 @@ impl Attention { head_dim, rotary_emb, use_flash_attn: cfg.use_flash_attn, + sliding_window: Some(cfg.sliding_window), }) } @@ -165,12 +167,40 @@ impl Attention { .contiguous()?; } - let (k, v) = match &*kv_cache { - None => (k, v), - Some((prev_k, prev_v)) => { - let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?; - let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?; - (k, v) + let (k, v, attn_mask) = match kv_cache.clone() { + None => (k, v, attention_mask.cloned()), + Some((mut prev_k, mut prev_v)) => { + let mut mask = attention_mask.cloned(); + if let Some(sliding_window) = self.sliding_window { + let kv_seq_len = prev_k.dim(2)?; + if kv_seq_len > sliding_window { + prev_k = prev_k.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + prev_v = prev_v.narrow( + 2, + kv_seq_len - (sliding_window - 1), + sliding_window - 1, + )?; + if let Some(ref mut mask) = mask { + let mask_len = mask.dim(1)?; + *mask = mask.narrow( + 1, + mask_len - (sliding_window - 1), + sliding_window - 1, + )?; + *mask = Tensor::cat( + &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?], + D::Minus1, + )?; + } + } + } + let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?; + let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?; + (k, v, mask) } }; *kv_cache = Some((k.clone(), v.clone())); @@ -189,9 +219,9 @@ impl Attention { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weights = match attention_mask { + let attn_weights = match attn_mask { None => attn_weights, - Some(mask) => attn_weights.broadcast_add(mask)?, + Some(mask) => attn_weights.broadcast_add(&mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&v)?
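For reference, the banded causal mask that patches 01 and 09 build in prepare_decoder_attention_mask can be sketched without any tensor machinery. The dependency-free function below mirrors the mask-building loop from those patches; the name sliding_window_causal_mask and the Vec<f32> return value are illustrative only, since the real code hands the buffer to Tensor::from_slice and expands it per batch.

/// Row-major tgt_len x tgt_len mask: entry (i, j) is 0.0 where position i may
/// attend to position j, and -inf where it may not, so the entry vanishes
/// after softmax. Position i attends to j only if j <= i (causal) and
/// i - j <= sliding_window (banded), matching the loop in the patches.
fn sliding_window_causal_mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<f32> {
    // With no window configured, tgt_len + 1 makes the band condition vacuous,
    // reducing this to a plain causal mask.
    let window = sliding_window.unwrap_or(tgt_len + 1);
    (0..tgt_len)
        .flat_map(|i| {
            (0..tgt_len).map(move |j| {
                if i < j || j + window < i {
                    f32::NEG_INFINITY
                } else {
                    0.0
                }
            })
        })
        .collect()
}

This mask is only built during prompt processing; when seq_len <= 1 the models pass no mask at all, and the KV-cache truncation sketched after patch 10 is what bounds the attended span during decode.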