
Commit

Use it in quantized llama
EricLBuehler committed Oct 29, 2024
1 parent 49c7255 commit 10d357a
Showing 1 changed file with 20 additions and 14 deletions.
candle-transformers/src/models/quantized_llama.rs: 34 changes (20 additions, 14 deletions)
@@ -205,21 +205,27 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // Support for MQA, useful for 70B models and mistral.
-        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
-        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
-
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let att = match mask {
-            None => att,
-            Some(mask) => {
-                let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
-            }
-        };
-        let att = candle_nn::ops::softmax_last_dim(&att)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
+        let y = if q.device().is_metal() && seq_len == 1 {
+            // SDPA will do MQA for us
+            candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)?
+        } else {
+            // Support for MQA, useful for 70B models and mistral.
+            let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
+            let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
+
+            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+            let att = match mask {
+                None => att,
+                Some(mask) => {
+                    let mask = mask.broadcast_as(att.shape())?;
+                    masked_fill(&att, &mask, &self.neg_inf)?
+                }
+            };
+            let att = candle_nn::ops::softmax_last_dim(&att)?;
+            // Convert to contiguous as matmul doesn't support strided vs for now.
+            att.matmul(&v.contiguous()?)?
+        };
 
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
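
The new fast path only runs for single-token decode steps (seq_len == 1) on a Metal device and hands the un-repeated K/V straight to candle_nn::ops::sdpa, with scale 1/sqrt(head_dim) and softmax_scale 1, exactly as in the diff above. A hypothetical standalone call with decode-step shapes might look like the sketch below; the device setup, head counts, dtype, and random tensors are illustrative assumptions, not part of the commit, and the imports assume the candle_core crate name.

    use candle_core::{Device, Result, Tensor};

    // Hypothetical decode-step shapes for the Metal-only fast path; not taken from the commit.
    fn sdpa_decode_step_sketch() -> Result<Tensor> {
        // Assumption: a Metal device is available; every other case takes the fallback branch.
        let dev = Device::new_metal(0)?;
        let (b, n_head, n_kv_head, head_dim) = (1usize, 32, 8, 128);
        let cached = 16usize; // tokens already sitting in the KV cache

        // Single-token query (seq_len == 1); K/V keep their n_kv_head heads because,
        // per the comment in the diff, SDPA performs the MQA/GQA head expansion itself.
        let q = Tensor::randn(0f32, 1., (b, n_head, 1, head_dim), &dev)?;
        let k = Tensor::randn(0f32, 1., (b, n_kv_head, cached + 1, head_dim), &dev)?;
        let v = Tensor::randn(0f32, 1., (b, n_kv_head, cached + 1, head_dim), &dev)?;

        // Same call as the fast path: scale = 1 / sqrt(head_dim), softmax_scale = 1.
        candle_nn::ops::sdpa(&q, &k, &v, 1. / (head_dim as f32).sqrt(), 1.)
    }

The scale argument mirrors the 1/sqrt(head_dim) division that the manual path applies to q.matmul(&k.t()?), so both branches compute the same attention up to kernel-level differences.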
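On the fallback path the key/value tensors still carry only n_kv_head heads, so crate::utils::repeat_kv expands them to match the n_head query heads before the q.matmul(&k.t()?) step. The sketch below shows one way such a repeat_kv-style expansion can be written, assuming a (batch, n_kv_head, seq_len, head_dim) layout; the helper that ships in candle-transformers may differ in detail.

    use candle_core::{Result, Tensor};

    // Sketch of a repeat_kv-style expansion for grouped-query attention; the real helper
    // lives in candle-transformers' utils and may be implemented differently.
    // Layout assumed: (batch, n_kv_head, seq_len, head_dim), n_rep = n_head / n_kv_head.
    fn repeat_kv_sketch(xs: Tensor, n_rep: usize) -> Result<Tensor> {
        if n_rep == 1 {
            return Ok(xs);
        }
        let (b, n_kv_head, seq_len, head_dim) = xs.dims4()?;
        // n_rep copies concatenated along the sequence axis give (b, n_kv_head, n_rep * seq, hd);
        // the reshape then folds each copy out as its own head: (b, n_kv_head * n_rep, seq, hd).
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b, n_kv_head * n_rep, seq_len, head_dim))
    }

Concatenating and reshaping also leaves the result contiguous, which matters for the fallback path since, as its own comment notes, the final att.matmul call does not yet support a strided v.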
