This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

add Falcon 40B model support #368

Merged: 13 commits, Jul 27, 2023
14 changes: 14 additions & 0 deletions crates/ggml/src/context.rs
@@ -561,6 +561,20 @@ impl Context {
let tensor = unsafe { sys::ggml_gelu(self.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

+ /// Flash attention: fused scaled-dot-product attention over `q`, `k` and `v`,
+ /// applying a causal mask when `masked` is true.
+ pub fn op_flash_attn(&self, q: &Tensor, k: &Tensor, v: &Tensor, masked: bool) -> Tensor {
+ let tensor = unsafe {
+ sys::ggml_flash_attn(
+ self.as_ptr(),
+ q.ptr.as_ptr(),
+ k.ptr.as_ptr(),
+ v.ptr.as_ptr(),
+ masked,
+ )
+ };
+ self.new_tensor_raw(tensor)
+ }
}
// Public to this crate methods
impl Context {
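The new op_flash_attn wrapper forwards directly to ggml_flash_attn, which evaluates the whole attention step (Q·K^T, scaling, optional causal masking, softmax, multiplication by V) as a single fused operation. A minimal, hypothetical usage sketch, assuming `ctx0` is an existing Context and `q`, `k`, `v` are tensors already laid out the way ggml_flash_attn expects:

// hypothetical fragment, not part of this diff
let kqv = ctx0.op_flash_attn(&q, &k, &v, true); // true => apply the causal mask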
110 changes: 75 additions & 35 deletions crates/models/falcon/src/lib.rs
@@ -63,10 +63,29 @@ impl KnownModel for Falcon {
let lm_head = tl.load("lm_head.weight")?;

let mut layers = Vec::new();
+ // n_head_kv distinguishes the model variant: 1 for Falcon-7B, more than 1 for Falcon-40B
+ let Hyperparameters { n_head_kv, .. } = hyperparameters;
for i in 0..hyperparameters.n_layer {
+ let (input_layernorm_name, attention_norm_name) = if n_head_kv == 1 {
+ // falcon 7b
+ (format!("transformer.h.{i}.input_layernorm"), None)
+ } else {
+ // falcon 40b
+ (
+ format!("transformer.h.{i}.ln_mlp"),
+ Some(format!("transformer.h.{i}.ln_attn")),
+ )
+ };
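+ // For orientation, the resulting tensor-name mapping is:
+ //   Falcon-7B : transformer.h.{i}.input_layernorm.{weight,bias} -> input_layernorm(_b)
+ //   Falcon-40B: transformer.h.{i}.ln_mlp.{weight,bias}          -> input_layernorm(_b)
+ //               transformer.h.{i}.ln_attn.{weight,bias}         -> attention_norm(_b)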
let layer = Layer {
- attention_norm: tl.load(&format!("transformer.h.{i}.input_layernorm.weight"))?,
- attention_norm_b: tl.load(&format!("transformer.h.{i}.input_layernorm.bias"))?,
+ input_layernorm: tl.load(&format!("{}.weight", input_layernorm_name))?,
+ input_layernorm_b: tl.load(&format!("{}.bias", input_layernorm_name))?,
+ attention_norm: attention_norm_name
+ .as_ref()
+ .map(|path| tl.load(&format!("{}.weight", path)))
+ .transpose()?,
+ attention_norm_b: attention_norm_name
+ .map(|path| tl.load(&format!("{}.bias", path)))
+ .transpose()?,

query_key_value: tl.load(&format!(
"transformer.h.{i}.self_attention.query_key_value.weight"
@@ -118,6 +137,7 @@ impl KnownModel for Falcon {
let Hyperparameters {
n_embd,
n_head,
+ n_head_kv,
n_vocab,
n_layer,
..
@@ -130,12 +150,6 @@ impl KnownModel for Falcon {
let ctx0 = builder.ctx0.borrow();
let embd = builder.embd;
let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd);
- let repeat_dummy = ctx0.new_tensor_3d(
- input_layer.get_type(),
- head_dim,
- input_len + session_len,
- n_head,
- );

let f32_size = std::mem::size_of::<f32>();

@@ -155,21 +169,40 @@ impl KnownModel for Falcon {
ctx0.use_scratch(builder.get_scratch(0));

// self-attention
- current = ctx0.op_norm(&input_layer);
- current = ctx0.op_add(
+ layernorm_output = ctx0.op_norm(&input_layer);
+ layernorm_output = ctx0.op_add(
&ctx0.op_mul(
- &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
- &current,
+ &ctx0.op_repeat(&self.layers[il].input_layernorm, &layernorm_output),
+ &layernorm_output,
),
- &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
+ &ctx0.op_repeat(&self.layers[il].input_layernorm_b, &layernorm_output),
);

- layernorm_output = current.share();
+ if n_head_kv == 1 {
+ // Falcon-7B only
+ current = layernorm_output.share();
+ } else {
+ // Falcon-40B only
+ current = ctx0.op_norm(&input_layer);
+ current = ctx0.op_add(
+ &ctx0.op_mul(
+ &ctx0.op_repeat(
+ self.layers[il].attention_norm.as_ref().unwrap(),
+ &current,
+ ),
+ &current,
+ ),
+ &ctx0.op_repeat(
+ self.layers[il].attention_norm_b.as_ref().unwrap(),
+ &current,
+ ),
+ );
+ }
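+ // Falcon-40B norms the same residual input twice: ln_attn (`attention_norm`) produces `current`
+ // for the attention path, while ln_mlp (`input_layernorm`, kept in `layernorm_output`) feeds the
+ // parallel MLP branch further down; Falcon-7B reuses its single input_layernorm for both.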

// compute QKV
current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);

- let fused_qkv_row_nb = (n_embd + 2 * (n_embd / n_head)) * f32_size;
+ let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size;
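+ // Worked example, assuming the published checkpoint shapes:
+ //   Falcon-7B : n_embd = 4544, n_head = 71, n_head_kv = 1 -> head_dim = 64,
+ //               row = 64 * (71 + 2 * 1) * 4 bytes = 18_688 bytes per token
+ //   Falcon-40B: n_embd = 8192, n_head = 128, n_head_kv = 8 -> head_dim = 64,
+ //               row = 64 * (128 + 2 * 8) * 4 bytes = 36_864 bytes per token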

let mut qcur = ctx0.op_view_3d(
&current,
@@ -180,16 +213,16 @@ impl KnownModel for Falcon {

let mut kcur = ctx0.op_view_3d(
&current,
- (head_dim, 1, n),
+ (head_dim, n_head_kv, n),
(head_dim * f32_size, fused_qkv_row_nb),
- n_embd * f32_size,
+ head_dim * n_head * f32_size,
);

let vcur = ctx0.op_view_3d(
&current,
- (head_dim, 1, n),
+ (head_dim, n_head_kv, n),
(head_dim * f32_size, fused_qkv_row_nb),
- (n_embd + head_dim) * f32_size,
+ head_dim * (n_head + n_head_kv) * f32_size,
);
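// Each fused QKV row is laid out as [Q: n_head * head_dim | K: n_head_kv * head_dim | V: n_head_kv * head_dim];
// the views above select Q at the start of the row, K at byte offset head_dim * n_head * f32_size,
// and V at head_dim * (n_head + n_head_kv) * f32_size.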

// using mode = 2 for neox mode
@@ -200,13 +233,13 @@ impl KnownModel for Falcon {

let k = ctx0.op_view_1d(
memory_k,
- n * head_dim,
- (memory_k_size * head_dim) * (il * ctx_size + session_len),
+ n * n_head_kv * head_dim,
+ (memory_k_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
);
let v = ctx0.op_view_1d(
memory_v,
- n * head_dim,
- (memory_v_size * head_dim) * (il * ctx_size + session_len),
+ n * n_head_kv * head_dim,
+ (memory_v_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
);

gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
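// With grouped K/V heads, each layer's cache slice holds ctx_size * n_head_kv * head_dim elements
// per tensor (previously ctx_size * head_dim), and session_len indexes token rows within that slice.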
@@ -215,21 +248,21 @@ impl KnownModel for Falcon {
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3));

- let mut bigk = ctx0.op_permute(
+ let bigk = ctx0.op_permute(
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
memory_k,
- (session_len + n) * head_dim,
- il * ctx_size * memory_k_size * head_dim,
+ (session_len + n) * n_head_kv * head_dim,
+ il * ctx_size * memory_k_size * n_head_kv * head_dim,
),
head_dim,
- 1,
+ n_head_kv,
session_len + n,
),
(0, 2, 1, 3),
);

// K * Q
- bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
let big_kq = ctx0.op_mul_mat(&bigk, &bigq);
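// Note: the old op_repeat of K into `repeat_dummy` is gone; this relies on ggml_mul_mat
// broadcasting the n_head_kv key heads across the n_head query heads (and likewise for V below).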

// KQ_scaled = KQ / sqrt(n_embd/n_head)
@@ -246,18 +279,17 @@ impl KnownModel for Falcon {
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
memory_v,
- (session_len + n) * head_dim,
- il * ctx_size * memory_v_size * head_dim,
+ (session_len + n) * n_head_kv * head_dim,
+ il * ctx_size * memory_v_size * n_head_kv * head_dim,
),
head_dim,
- 1,
+ n_head_kv,
session_len + n,
),
(0, 2, 1, 3),
);
- bigv = ctx0.op_cont(&ctx0.op_transpose(&ctx0.op_repeat(&bigv, &repeat_dummy)));
+ bigv = ctx0.op_cont(&ctx0.op_transpose(&bigv));

// KQV = transpose(V) * KQ_soft_max
let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax);
// KQV_merged = KQV.permute(0, 2, 1, 3)
let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3));
@@ -361,6 +393,8 @@ pub struct Hyperparameters {
n_embd: usize,
/// n_heads
n_head: usize,
+ /// Number of key-value heads (1 for Falcon-7B, more for Falcon-40B)
+ n_head_kv: usize,
/// Number of layers in the model
n_layer: usize,
/// file_type
@@ -373,6 +407,7 @@ impl llm_base::Hyperparameters for Hyperparameters {
n_vocab: util::read_i32(reader)?.try_into()?,
n_embd: util::read_i32(reader)?.try_into()?,
n_head: util::read_i32(reader)?.try_into()?,
+ n_head_kv: util::read_i32(reader)?.try_into()?,
n_layer: util::read_i32(reader)?.try_into()?,
file_type: util::read_filetype(reader)?,
};
@@ -384,6 +419,7 @@ impl llm_base::Hyperparameters for Hyperparameters {
util::write_i32(writer, self.n_vocab.try_into()?)?;
util::write_i32(writer, self.n_embd.try_into()?)?;
util::write_i32(writer, self.n_head.try_into()?)?;
+ util::write_i32(writer, self.n_head_kv.try_into()?)?;
util::write_i32(writer, self.n_layer.try_into()?)?;
util::write_i32(writer, self.file_type.into())?;
Ok(())
@@ -404,8 +440,12 @@ impl llm_base::Hyperparameters for Hyperparameters {

struct Layer {
// normalization
- attention_norm: Tensor,
- attention_norm_b: Tensor,
+ input_layernorm: Tensor,
+ input_layernorm_b: Tensor,
+
+ // Falcon-40B only
+ attention_norm: Option<Tensor>,
+ attention_norm_b: Option<Tensor>,

// attention
query_key_value: Tensor,