Merge pull request #244 from EricLBuehler/phi3_sliding_window
Sliding window for phi3
EricLBuehler authored Apr 29, 2024
2 parents 49cd334 + d3439af commit d442e02
Showing 8 changed files with 282 additions and 67 deletions.
24 changes: 11 additions & 13 deletions mistralrs-core/src/layers.rs
@@ -137,22 +137,20 @@ impl PhiRotaryEmbedding {
let inv_freq_len = inv_freq_long.len();

let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
- .to_dtype(dtype)?
+ .to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;

// Calculate sin,cos for long
- let inv_freq_long =
- Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?;
let freqs_long = t.matmul(&inv_freq_long)?;
- let long_sin = freqs_long.sin()?.mul(scaling_factor)?;
- let long_cos = freqs_long.cos()?.mul(scaling_factor)?;
+ let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
+ let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

// Calculate sin,cos for short
- let inv_freq_short =
- Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?;
let freqs_short = t.matmul(&inv_freq_short)?;
- let short_sin = freqs_short.sin()?.mul(scaling_factor)?;
- let short_cos = freqs_short.cos()?.mul(scaling_factor)?;
+ let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?;
+ let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?;

Ok(Self {
short_cos,
@@ -167,13 +165,13 @@ impl PhiRotaryEmbedding {
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
- let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
- .to_dtype(dtype)?
+ .to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let sin = freqs.sin()?;
- let cos = freqs.cos()?;
+ let sin = freqs.sin()?.to_dtype(dtype)?;
+ let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self {
short_cos: cos,
short_sin: sin,
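The layers.rs hunks above change PhiRotaryEmbedding to build its RoPE tables in f32 and cast to the model dtype only after sin/cos, rather than casting the inputs up front, presumably to avoid precision loss when the model runs in f16/bf16. A minimal candle sketch of that pattern, collapsed into one standalone function (the function name, signature, and the omission of the long/short scaling split are simplifications, not the actual mistral.rs helper):

use candle_core::{DType, Device, Result, Tensor};

// Build RoPE sin/cos tables in f32; cast to the target dtype only at the end.
fn rope_tables(
    dim: usize,
    max_seq_len: usize,
    rope_theta: f64,
    dtype: DType,
    dev: &Device,
) -> Result<(Tensor, Tensor)> {
    let inv_freq: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
        .collect();
    let inv_freq_len = inv_freq.len();
    // Positions stay in f32 (no early cast to `dtype`)...
    let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?;
    // ...and so do the inverse frequencies.
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
    let freqs = t.matmul(&inv_freq)?;
    // Only the finished tables are cast, mirroring the diff above.
    Ok((freqs.sin()?.to_dtype(dtype)?, freqs.cos()?.to_dtype(dtype)?))
}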
46 changes: 38 additions & 8 deletions mistralrs-core/src/models/mistral.rs
@@ -85,6 +85,7 @@ struct Attention {
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
+ sliding_window: Option<usize>,
}

impl Attention {
@@ -110,6 +111,7 @@ impl Attention {
hidden_size: hidden_sz,
rotary_emb,
use_flash_attn: cfg.use_flash_attn,
+ sliding_window: cfg.sliding_window,
})
}

@@ -157,12 +159,40 @@ impl Attention {
.contiguous()?;
}

- let (k, v) = match &*kv_cache {
- None => (k, v),
- Some((prev_k, prev_v)) => {
- let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
- let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
- (k, v)
+ let (k, v, attn_mask) = match kv_cache.clone() {
+ None => (k, v, attention_mask.cloned()),
+ Some((mut prev_k, mut prev_v)) => {
+ let mut mask = attention_mask.cloned();
+ if let Some(sliding_window) = self.sliding_window {
+ let kv_seq_len = prev_k.dim(2)?;
+ if kv_seq_len > sliding_window {
+ prev_k = prev_k.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ prev_v = prev_v.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ if let Some(ref mut mask) = mask {
+ let mask_len = mask.dim(1)?;
+ *mask = mask.narrow(
+ 1,
+ mask_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ *mask = Tensor::cat(
+ &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+ D::Minus1,
+ )?;
+ }
+ }
+ }
+ let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+ let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+ (k, v, mask)
}
};
*kv_cache = Some((k.clone(), v.clone()));
@@ -181,9 +211,9 @@ impl Attention {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

- let attn_weights = match attention_mask {
+ let attn_weights = match attn_mask {
None => attn_weights,
- Some(mask) => attn_weights.broadcast_add(mask)?,
+ Some(mask) => attn_weights.broadcast_add(&mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v)?
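In mistral.rs (and identically in mixtral.rs and phi3.rs below), the cached K/V is narrowed to its last sliding_window - 1 positions before the new token's K/V is concatenated, so at most sliding_window positions are ever attended. A plain-Rust sketch of that trimming rule on slices (the helper name and slice types are illustrative; the real code narrows candle tensors along the sequence dimension):

// Keep only the last `window - 1` cached positions so that, once the new
// token's entry is appended, at most `window` positions remain.
fn trim_kv_cache<T: Clone>(prev: &[T], window: usize) -> Vec<T> {
    if prev.len() > window {
        // Mirrors `prev_k.narrow(2, kv_seq_len - (window - 1), window - 1)` above.
        prev[prev.len() - (window - 1)..].to_vec()
    } else {
        prev.to_vec()
    }
}

fn main() {
    // Window of 4, 10 cached positions: positions 7, 8, 9 survive the trim;
    // appending the new position 10 yields exactly 4 attended positions.
    let cache: Vec<usize> = (0..10).collect();
    assert_eq!(trim_kv_cache(&cache, 4), vec![7, 8, 9]);
}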
46 changes: 38 additions & 8 deletions mistralrs-core/src/models/mixtral.rs
@@ -49,6 +49,7 @@ struct Attention {
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
+ sliding_window: Option<usize>,
}

impl Attention {
@@ -74,6 +75,7 @@ impl Attention {
hidden_size: hidden_sz,
rotary_emb,
use_flash_attn: cfg.use_flash_attn,
+ sliding_window: Some(cfg.sliding_window),
})
}

@@ -121,12 +123,40 @@ impl Attention {
.contiguous()?;
}

- let (k, v) = match &*kv_cache {
- None => (k, v),
- Some((prev_k, prev_v)) => {
- let k = candle_nn::ops::kvconcat(prev_k, &k, 2)?;
- let v = candle_nn::ops::kvconcat(prev_v, &v, 2)?;
- (k, v)
+ let (k, v, attn_mask) = match kv_cache.clone() {
+ None => (k, v, attention_mask.cloned()),
+ Some((mut prev_k, mut prev_v)) => {
+ let mut mask = attention_mask.cloned();
+ if let Some(sliding_window) = self.sliding_window {
+ let kv_seq_len = prev_k.dim(2)?;
+ if kv_seq_len > sliding_window {
+ prev_k = prev_k.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ prev_v = prev_v.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ if let Some(ref mut mask) = mask {
+ let mask_len = mask.dim(1)?;
+ *mask = mask.narrow(
+ 1,
+ mask_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ *mask = Tensor::cat(
+ &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+ D::Minus1,
+ )?;
+ }
+ }
+ }
+ let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?;
+ let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?;
+ (k, v, mask)
}
};
*kv_cache = Some((k.clone(), v.clone()));
@@ -145,9 +175,9 @@ impl Attention {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

- let attn_weights = match attention_mask {
+ let attn_weights = match attn_mask {
None => attn_weights,
- Some(mask) => attn_weights.broadcast_add(mask)?,
+ Some(mask) => attn_weights.broadcast_add(&mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v)?
70 changes: 59 additions & 11 deletions mistralrs-core/src/models/phi3.rs
@@ -33,6 +33,7 @@ pub struct Config {
pub rope_scaling: Option<HashMap<String, Either<Vec<f32>, String>>>,
pub max_position_embeddings: usize,
pub use_flash_attn: bool,
+ pub sliding_window: Option<usize>,
pub original_max_position_embeddings: usize,
}

@@ -52,6 +53,7 @@ struct Attention {
head_dim: usize,
rotary_emb: Arc<PhiRotaryEmbedding>,
use_flash_attn: bool,
+ sliding_window: Option<usize>,
}

impl Attention {
@@ -71,6 +73,7 @@ impl Attention {
num_kv_groups: num_heads / num_kv_heads,
head_dim,
use_flash_attn: cfg.use_flash_attn,
+ sliding_window: cfg.sliding_window,
})
}

@@ -114,12 +117,40 @@ impl Attention {

let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

- let (k, v) = match &*kv_cache {
- None => (k, v),
- Some((prev_k, prev_v)) => {
- let k = Tensor::cat(&[prev_k, &k], 2)?;
- let v = Tensor::cat(&[prev_v, &v], 2)?;
- (k, v)
+ let (k, v, attn_mask) = match kv_cache.clone() {
+ None => (k, v, attention_mask.cloned()),
+ Some((mut prev_k, mut prev_v)) => {
+ let mut mask = attention_mask.cloned();
+ if let Some(sliding_window) = self.sliding_window {
+ let kv_seq_len = prev_k.dim(2)?;
+ if kv_seq_len > sliding_window {
+ prev_k = prev_k.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ prev_v = prev_v.narrow(
+ 2,
+ kv_seq_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ if let Some(ref mut mask) = mask {
+ let mask_len = mask.dim(1)?;
+ *mask = mask.narrow(
+ 1,
+ mask_len - (sliding_window - 1),
+ sliding_window - 1,
+ )?;
+ *mask = Tensor::cat(
+ &[&*mask, &mask.narrow(1, mask_len - 1, 1)?.ones_like()?],
+ D::Minus1,
+ )?;
+ }
+ }
+ }
+ let k = Tensor::cat(&[prev_k, k], 2)?;
+ let v = Tensor::cat(&[prev_v, v], 2)?;
+ (k, v, mask)
}
};
*kv_cache = Some((k.clone(), v.clone()));
@@ -138,9 +169,9 @@ impl Attention {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

- let attn_weights = match attention_mask {
+ let attn_weights = match attn_mask {
None => attn_weights,
- Some(mask) => attn_weights.broadcast_add(mask)?,
+ Some(mask) => attn_weights.broadcast_add(&mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v)?
@@ -277,6 +308,7 @@ pub struct Model {
pub cache: Cache,
pub max_seq_len: usize,
mapper: Box<dyn DeviceMapper + Send + Sync>,
+ sliding_window: Option<usize>,
}

impl Model {
@@ -333,6 +365,7 @@ impl Model {
cache: Cache::new(cfg.num_hidden_layers, false),
max_seq_len: cfg.max_position_embeddings,
mapper,
+ sliding_window: cfg.sliding_window,
})
}

@@ -341,9 +374,20 @@ impl Model {
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
+ sliding_window: Option<usize>,
) -> Result<Tensor> {
+ // Sliding window mask
+ let sliding_window = sliding_window.unwrap_or(tgt_len + 1);
let mask: Vec<_> = (0..tgt_len)
- .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
@@ -384,8 +428,12 @@ impl Model {
let attention_mask = if seq_len <= 1 {
None
} else {
- let mask =
- self.prepare_decoder_attention_mask(b_size, seq_len, past_key_values_length)?;
+ let mask = self.prepare_decoder_attention_mask(
+ b_size,
+ seq_len,
+ past_key_values_length,
+ self.sliding_window,
+ )?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
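The new prepare_decoder_attention_mask in phi3.rs masks a position if it is either in the future or more than sliding_window positions behind the query. A standalone sketch of the same rule without the tensor plumbing (the function name and the flat Vec<f32> return are illustrative):

// Row-major (tgt_len x tgt_len) additive mask: position i may attend to
// position j only if j <= i and i - j <= sliding_window.
fn sliding_window_causal_mask(tgt_len: usize, sliding_window: Option<usize>) -> Vec<f32> {
    // No window configured: tgt_len + 1 can never be exceeded, so this
    // degenerates to a plain causal mask.
    let w = sliding_window.unwrap_or(tgt_len + 1);
    (0..tgt_len)
        .flat_map(|i| {
            (0..tgt_len).map(move |j| {
                if i < j || j + w < i {
                    f32::NEG_INFINITY // future token, or outside the window
                } else {
                    0.
                }
            })
        })
        .collect()
}

fn main() {
    // tgt_len = 5, window = 2: row 4 may attend to columns 2, 3, 4 only.
    let mask = sliding_window_causal_mask(5, Some(2));
    let row4 = &mask[4 * 5..5 * 5];
    assert!(row4[0].is_infinite() && row4[1].is_infinite());
    assert_eq!(&row4[2..], &[0.0, 0.0, 0.0]);
}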
2 changes: 2 additions & 0 deletions mistralrs-core/src/pipeline/loaders.rs
@@ -520,6 +520,7 @@ struct Phi3BasicConfig {
rope_scaling: Option<HashMap<String, RopeScaling>>,
max_position_embeddings: usize,
original_max_position_embeddings: usize,
+ sliding_window: Option<usize>,
}

impl Phi3BasicConfig {
Expand All @@ -545,6 +546,7 @@ impl Phi3BasicConfig {
}),
original_max_position_embeddings: basic_config.original_max_position_embeddings,
use_flash_attn,
+ sliding_window: basic_config.sliding_window,
})
}
}
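In loaders.rs, sliding_window is deserialized as Option<usize>, so Phi 3 configs that omit the field still load (serde turns a missing Option field into None). A small illustrative sketch, assuming the config struct derives serde's Deserialize as the surrounding loader code suggests; the struct name and JSON values below are hypothetical, not taken from a real config.json:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Phi3ConfigSketch {
    // Optional: absent in older exports, present when the model uses a window.
    sliding_window: Option<usize>,
}

fn main() -> Result<(), serde_json::Error> {
    let with: Phi3ConfigSketch = serde_json::from_str(r#"{ "sliding_window": 2048 }"#)?;
    let without: Phi3ConfigSketch = serde_json::from_str("{}")?;
    assert_eq!(with.sliding_window, Some(2048));
    assert_eq!(without.sliding_window, None);
    Ok(())
}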
