diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 7dc7a6369..c19f1908c 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -137,22 +137,20 @@ impl PhiRotaryEmbedding {
         let inv_freq_len = inv_freq_long.len();
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
+            .to_dtype(DType::F32)?
             .reshape((max_seq_len, 1))?;

         // Calculate sin,cos for long
-        let inv_freq_long =
-            Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
         let freqs_long = t.matmul(&inv_freq_long)?;
-        let long_sin = freqs_long.sin()?.mul(scaling_factor)?;
-        let long_cos = freqs_long.cos()?.mul(scaling_factor)?;
+        let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
+        let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

         // Calculate sin,cos for short
-        let inv_freq_short =
-            Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?;
         let freqs_short = t.matmul(&inv_freq_short)?;
-        let short_sin = freqs_short.sin()?.mul(scaling_factor)?;
-        let short_cos = freqs_short.cos()?.mul(scaling_factor)?;
+        let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
+        let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

         Ok(Self {
             short_cos,
@@ -167,13 +165,13 @@ impl PhiRotaryEmbedding {
             .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
             .collect();
         let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
+            .to_dtype(DType::F32)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let sin = freqs.sin()?;
-        let cos = freqs.cos()?;
+        let sin = freqs.sin()?.to_dtype(dtype)?;
+        let cos = freqs.cos()?.to_dtype(dtype)?;
         Ok(Self {
             short_cos: cos,
             short_sin: sin,
diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs
index 244f7f564..2f3a45779 100644
--- a/mistralrs-core/src/models/mistral.rs
+++ b/mistralrs-core/src/models/mistral.rs
@@ -85,6 +85,7 @@ struct Attention {
     hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -110,6 +111,7 @@ impl Attention {
             hidden_size: hidden_sz,
             rotary_emb,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -157,12 +159,40 @@
                 .contiguous()?;
         }

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
-                let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+                let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -181,9 +211,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
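For reference, the cache-update path added to each attention block above boils down to the following standalone sketch. It is illustrative only, assuming candle's Tensor API; `truncate_and_append` and its free-function form are not items in this crate. When the cache already holds more than `sliding_window` positions, only the newest `sliding_window - 1` are kept before the current step's keys/values are appended, so a single-token decode step ends with at most `sliding_window` cached positions.

// Hypothetical helper restating the rolling-window cache update from the hunks above.
use candle_core::{Result, Tensor};

fn truncate_and_append(
    prev_k: &Tensor,
    prev_v: &Tensor,
    k: &Tensor,
    v: &Tensor,
    sliding_window: usize,
) -> Result<(Tensor, Tensor)> {
    // The sequence dimension of the cache is dim 2: (batch, kv_heads, seq, head_dim).
    let kv_seq_len = prev_k.dim(2)?;
    let (prev_k, prev_v) = if kv_seq_len > sliding_window {
        // Keep only the newest `sliding_window - 1` cached positions so that the
        // position appended below stays within the window.
        let start = kv_seq_len - (sliding_window - 1);
        (
            prev_k.narrow(2, start, sliding_window - 1)?,
            prev_v.narrow(2, start, sliding_window - 1)?,
        )
    } else {
        (prev_k.clone(), prev_v.clone())
    };
    Ok((
        Tensor::cat(&[&prev_k, k], 2)?,
        Tensor::cat(&[&prev_v, v], 2)?,
    ))
}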
diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs
index d5797c801..2a1887040 100644
--- a/mistralrs-core/src/models/mixtral.rs
+++ b/mistralrs-core/src/models/mixtral.rs
@@ -49,6 +49,7 @@ struct Attention {
     hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -74,6 +75,7 @@ impl Attention {
             hidden_size: hidden_sz,
             rotary_emb,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: Some(cfg.sliding_window),
         })
     }

@@ -121,12 +123,40 @@
                 .contiguous()?;
         }

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
-                let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+                let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -145,9 +175,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs
index 3de891a6b..891e97084 100644
--- a/mistralrs-core/src/models/phi3.rs
+++ b/mistralrs-core/src/models/phi3.rs
@@ -33,6 +33,7 @@ pub struct Config {
     pub rope_scaling: Option<HashMap<String, Either<Vec<f32>, String>>>,
     pub max_position_embeddings: usize,
     pub use_flash_attn: bool,
+    pub sliding_window: Option<usize>,
     pub original_max_position_embeddings: usize,
 }

@@ -52,6 +53,7 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<PhiRotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -71,6 +73,7 @@ impl Attention {
             num_kv_groups: num_heads / num_kv_heads,
             head_dim,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -114,12 +117,40 @@
         let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = Tensor::cat(&[prev_k, &k], 2)?;
-                let v = Tensor::cat(&[prev_v, &v], 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = Tensor::cat(&[prev_k, k], 2)?;
+                let v = Tensor::cat(&[prev_v, v], 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -138,9 +169,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
@@ -277,6 +308,7 @@ pub struct Model {
     pub cache: Cache,
     pub max_seq_len: usize,
     mapper: Box<dyn DeviceMapper + Send + Sync>,
+    sliding_window: Option<usize>,
 }

 impl Model {
@@ -333,6 +365,7 @@ impl Model {
             cache: Cache::new(cfg.num_hidden_layers, false),
             max_seq_len: cfg.max_position_embeddings,
             mapper,
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -341,9 +374,20 @@
         b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
+        sliding_window: Option<usize>,
     ) -> Result<Tensor> {
+        // Sliding window mask
+        let sliding_window = sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
-            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+            .flat_map(|i| {
+                (0..tgt_len).map(move |j| {
+                    if i < j || j + sliding_window < i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
             .collect();
         let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
         let mask = if seqlen_offset > 0 {
@@ -384,8 +428,12 @@ impl Model {
         let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask =
-                self.prepare_decoder_attention_mask(b_size, seq_len, past_key_values_length)?;
+            let mask = self.prepare_decoder_attention_mask(
+                b_size,
+                seq_len,
+                past_key_values_length,
+                self.sliding_window,
+            )?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
diff --git a/mistralrs-core/src/pipeline/loaders.rs b/mistralrs-core/src/pipeline/loaders.rs
index 8b3750153..cfcc20565 100644
--- a/mistralrs-core/src/pipeline/loaders.rs
+++ b/mistralrs-core/src/pipeline/loaders.rs
@@ -520,6 +520,7 @@ struct Phi3BasicConfig {
     rope_scaling: Option<HashMap<String, Value>>,
     max_position_embeddings: usize,
     original_max_position_embeddings: usize,
+    sliding_window: Option<usize>,
 }

 impl Phi3BasicConfig {
@@ -545,6 +546,7 @@ impl Phi3BasicConfig {
             }),
             original_max_position_embeddings: basic_config.original_max_position_embeddings,
             use_flash_attn,
+            sliding_window: basic_config.sliding_window,
         })
     }
 }
diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs
index 222cc0c23..61bbf07fd 100644
--- a/mistralrs-core/src/xlora_models/mistral.rs
+++ b/mistralrs-core/src/xlora_models/mistral.rs
@@ -128,6 +128,7 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -195,6 +196,7 @@ impl Attention {
             head_dim,
             rotary_emb,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -261,12 +263,40 @@
                 .contiguous()?;
         }

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
-                let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+                let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -285,9 +315,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs
index 9b594df56..b6826938d 100644
--- a/mistralrs-core/src/xlora_models/mixtral.rs
+++ b/mistralrs-core/src/xlora_models/mixtral.rs
@@ -32,6 +32,7 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -99,6 +100,7 @@ impl Attention {
             head_dim,
             rotary_emb,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: Some(cfg.sliding_window),
         })
     }

@@ -165,12 +167,40 @@
                 .contiguous()?;
         }

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
-                let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+                let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -189,9 +219,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs
index 13080c8d1..2456f2770 100644
--- a/mistralrs-core/src/xlora_models/phi3.rs
+++ b/mistralrs-core/src/xlora_models/phi3.rs
@@ -31,6 +31,7 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<PhiRotaryEmbedding>,
     use_flash_attn: bool,
+    sliding_window: Option<usize>,
 }

 impl Attention {
@@ -77,6 +78,7 @@ impl Attention {
             num_kv_groups: num_heads / num_kv_heads,
             head_dim,
             use_flash_attn: cfg.use_flash_attn,
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -129,12 +131,40 @@
         let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

-        let (k, v) = match &*kv_cache {
-            None => (k, v),
-            Some((prev_k, prev_v)) => {
-                let k = Tensor::cat(&[prev_k, &k], 2)?;
-                let v = Tensor::cat(&[prev_v, &v], 2)?;
-                (k, v)
+        let (k, v, attn_mask) = match kv_cache.clone() {
+            None => (k, v, attention_mask.cloned()),
+            Some((mut prev_k, mut prev_v)) => {
+                let mut mask = attention_mask.cloned();
+                if let Some(sliding_window) = self.sliding_window {
+                    let kv_seq_len = prev_k.dim(2)?;
+                    if kv_seq_len > sliding_window {
+                        prev_k = prev_k.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        prev_v = prev_v.narrow(
+                            2,
+                            kv_seq_len - (sliding_window - 1),
+                            sliding_window - 1,
+                        )?;
+                        if let Some(ref mut mask) = mask {
+                            let mask_len = mask.dim(1)?;
+                            *mask = mask.narrow(
+                                1,
+                                mask_len - (sliding_window - 1),
+                                sliding_window - 1,
+                            )?;
+                            *mask = Tensor::cat(
+                                &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+                                D::Minus1,
+                            )?;
+                        }
+                    }
+                }
+                let k = Tensor::cat(&[prev_k, k], 2)?;
+                let v = Tensor::cat(&[prev_v, v], 2)?;
+                (k, v, mask)
             }
         };
         *kv_cache = Some((k.clone(), v.clone()));
@@ -153,9 +183,9 @@
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

-        let attn_weights = match attention_mask {
+        let attn_weights = match attn_mask {
             None => attn_weights,
-            Some(mask) => attn_weights.broadcast_add(mask)?,
+            Some(mask) => attn_weights.broadcast_add(&mask)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         attn_weights.matmul(&v)?
@@ -365,6 +395,7 @@ pub struct Model {
     pub max_seq_len: usize,
     mapper: Box<dyn DeviceMapper + Send + Sync>,
     xlora_classifier: Option<XLoraClassifier>,
+    sliding_window: Option<usize>,
 }

 impl Model {
@@ -451,6 +482,7 @@ impl Model {
             xlora_classifier: xlora_config.map(|xlora_config| {
                 XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
             }),
+            sliding_window: cfg.sliding_window,
         })
     }

@@ -459,9 +491,20 @@
         b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
+        sliding_window: Option<usize>,
     ) -> Result<Tensor> {
+        // Sliding window mask
+        let sliding_window = sliding_window.unwrap_or(tgt_len + 1);
         let mask: Vec<_> = (0..tgt_len)
-            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+            .flat_map(|i| {
+                (0..tgt_len).map(move |j| {
+                    if i < j || j + sliding_window < i {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
             .collect();
         let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
         let mask = if seqlen_offset > 0 {
@@ -506,8 +549,12 @@
         let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask =
-                self.prepare_decoder_attention_mask(b_size, seq_len, past_key_values_length)?;
+            let mask = self.prepare_decoder_attention_mask(
+                b_size,
+                seq_len,
+                past_key_values_length,
+                self.sliding_window,
+            )?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
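The other half of the change is the banded causal mask that the updated `prepare_decoder_attention_mask` builds during prompt processing. A minimal free-standing restatement follows; the function name and signature are assumptions for illustration, not the crate's API.

// Sketch of the mask construction above. Entry (i, j) is 0.0 when query i may
// attend to key j, and -inf otherwise: j must not lie in the future (i < j) and
// must not be further back than the window (j + sliding_window < i).
use candle_core::{Device, Result, Tensor};

fn sliding_causal_mask(
    tgt_len: usize,
    sliding_window: Option<usize>,
    device: &Device,
) -> Result<Tensor> {
    // With no window configured, `tgt_len + 1` can never exclude a position,
    // so this degenerates to the ordinary causal mask.
    let sliding_window = sliding_window.unwrap_or(tgt_len + 1);
    let mask: Vec<f32> = (0..tgt_len)
        .flat_map(|i| {
            (0..tgt_len).map(move |j| {
                if i < j || j + sliding_window < i {
                    f32::NEG_INFINITY
                } else {
                    0.
                }
            })
        })
        .collect();
    Tensor::from_slice(&mask, (tgt_len, tgt_len), device)
}

For example, with `tgt_len = 5` and `sliding_window = 2`, row 4 comes out as `[-inf, -inf, 0, 0, 0]`: the last query sees itself and at most the two preceding positions.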