diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs index 472b58c1..33dbb00e 100644 --- a/crates/ggml/src/context.rs +++ b/crates/ggml/src/context.rs @@ -561,6 +561,20 @@ impl Context { let tensor = unsafe { sys::ggml_gelu(self.as_ptr(), a.ptr.as_ptr()) }; self.new_tensor_raw(tensor) } + + /// flash attention. + pub fn op_flash_attn(&self, q: &Tensor, k: &Tensor, v: &Tensor, masked: bool) -> Tensor { + let tensor = unsafe { + sys::ggml_flash_attn( + self.as_ptr(), + q.ptr.as_ptr(), + k.ptr.as_ptr(), + v.ptr.as_ptr(), + masked, + ) + }; + self.new_tensor_raw(tensor) + } } // Public to this crate methods impl Context { diff --git a/crates/models/falcon/src/lib.rs b/crates/models/falcon/src/lib.rs index 1c4c11c4..b647b361 100644 --- a/crates/models/falcon/src/lib.rs +++ b/crates/models/falcon/src/lib.rs @@ -63,10 +63,29 @@ impl KnownModel for Falcon { let lm_head = tl.load("lm_head.weight")?; let mut layers = Vec::new(); + // utilizing n_head_kv to determine the model version (parameters) + let Hyperparameters { n_head_kv, .. 
} = hyperparameters; for i in 0..hyperparameters.n_layer { + let (input_layernorm_name, attention_norm_name) = if n_head_kv == 1 { + // falcon 7b + (format!("transformer.h.{i}.input_layernorm"), None) + } else { + // falcon 40b + ( + format!("transformer.h.{i}.ln_mlp"), + Some(format!("transformer.h.{i}.ln_attn")), + ) + }; let layer = Layer { - attention_norm: tl.load(&format!("transformer.h.{i}.input_layernorm.weight"))?, - attention_norm_b: tl.load(&format!("transformer.h.{i}.input_layernorm.bias"))?, + input_layernorm: tl.load(&format!("{}.weight", input_layernorm_name))?, + input_layernorm_b: tl.load(&format!("{}.bias", input_layernorm_name))?, + attention_norm: attention_norm_name + .as_ref() + .map(|path| tl.load(&format!("{}.weight", path))) + .transpose()?, + attention_norm_b: attention_norm_name + .map(|path| tl.load(&format!("{}.bias", path))) + .transpose()?, query_key_value: tl.load(&format!( "transformer.h.{i}.self_attention.query_key_value.weight" @@ -118,6 +137,7 @@ impl KnownModel for Falcon { let Hyperparameters { n_embd, n_head, + n_head_kv, n_vocab, n_layer, .. 
@@ -130,12 +150,6 @@ impl KnownModel for Falcon { let ctx0 = builder.ctx0.borrow(); let embd = builder.embd; let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd); - let repeat_dummy = ctx0.new_tensor_3d( - input_layer.get_type(), - head_dim, - input_len + session_len, - n_head, - ); let f32_size = std::mem::size_of::<f32>(); @@ -155,21 +169,40 @@ impl KnownModel for Falcon { ctx0.use_scratch(builder.get_scratch(0)); // self-attention - current = ctx0.op_norm(&input_layer); - current = ctx0.op_add( + layernorm_output = ctx0.op_norm(&input_layer); + layernorm_output = ctx0.op_add( &ctx0.op_mul( - &ctx0.op_repeat(&self.layers[il].attention_norm, &current), - &current, + &ctx0.op_repeat(&self.layers[il].input_layernorm, &layernorm_output), + &layernorm_output, ), - &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current), + &ctx0.op_repeat(&self.layers[il].input_layernorm_b, &layernorm_output), ); - layernorm_output = current.share(); + if n_head_kv == 1 { + // Falcon-7B only + current = layernorm_output.share(); + } else { + // Falcon-40B only + current = ctx0.op_norm(&input_layer); + current = ctx0.op_add( + &ctx0.op_mul( + &ctx0.op_repeat( + self.layers[il].attention_norm.as_ref().unwrap(), + &current, + ), + &current, + ), + &ctx0.op_repeat( + self.layers[il].attention_norm_b.as_ref().unwrap(), + &current, + ), + ); + } // compute QKV current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current); - let fused_qkv_row_nb = (n_embd + 2 * (n_embd / n_head)) * f32_size; + let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size; let mut qcur = ctx0.op_view_3d( &current, @@ -180,16 +213,16 @@ impl KnownModel for Falcon { let mut kcur = ctx0.op_view_3d( &current, - (head_dim, 1, n), + (head_dim, n_head_kv, n), (head_dim * f32_size, fused_qkv_row_nb), - n_embd * f32_size, + head_dim * n_head * f32_size, ); let vcur = ctx0.op_view_3d( &current, - (head_dim, 1, n), + (head_dim, n_head_kv, n), (head_dim * f32_size, fused_qkv_row_nb), - (n_embd + head_dim) * f32_size, + head_dim * (n_head + n_head_kv) 
* f32_size, ); // using mode = 2 for neox mode @@ -200,13 +233,13 @@ impl KnownModel for Falcon { let k = ctx0.op_view_1d( memory_k, - n * head_dim, - (memory_k_size * head_dim) * (il * ctx_size + session_len), + n * n_head_kv * head_dim, + (memory_k_size * n_head_kv * head_dim) * (il * ctx_size + session_len), ); let v = ctx0.op_view_1d( memory_v, - n * head_dim, - (memory_v_size * head_dim) * (il * ctx_size + session_len), + n * n_head_kv * head_dim, + (memory_v_size * n_head_kv * head_dim) * (il * ctx_size + session_len), ); gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k)); @@ -215,21 +248,21 @@ impl KnownModel for Falcon { // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3)); - let mut bigk = ctx0.op_permute( + let bigk = ctx0.op_permute( &ctx0.op_reshape_3d( &ctx0.op_view_1d( memory_k, - (session_len + n) * head_dim, - il * ctx_size * memory_k_size * head_dim, + (session_len + n) * n_head_kv * head_dim, + il * ctx_size * memory_k_size * n_head_kv * head_dim, ), head_dim, - 1, + n_head_kv, session_len + n, ), (0, 2, 1, 3), ); + // K * Q - bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy)); let big_kq = ctx0.op_mul_mat(&bigk, &bigq); // KQ_scaled = KQ / sqrt(n_embd/n_head) @@ -246,18 +279,17 @@ impl KnownModel for Falcon { &ctx0.op_reshape_3d( &ctx0.op_view_1d( memory_v, - (session_len + n) * head_dim, - il * ctx_size * memory_v_size * head_dim, + (session_len + n) * n_head_kv * head_dim, + il * ctx_size * memory_v_size * n_head_kv * head_dim, ), head_dim, - 1, + n_head_kv, session_len + n, ), (0, 2, 1, 3), ); - bigv = ctx0.op_cont(&ctx0.op_transpose(&ctx0.op_repeat(&bigv, &repeat_dummy))); + bigv = ctx0.op_cont(&ctx0.op_transpose(&bigv)); - // KQV = transpose(V) * KQ_soft_max let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax); // KQV_merged = KQV.permute(0, 2, 1, 3) let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3)); @@ -361,6 +393,8 @@ pub struct Hyperparameters { 
n_embd: usize, /// n_heads n_head: usize, + // Number of heads for key-value pairs + n_head_kv: usize, /// Number of layers in the model n_layer: usize, /// file_type @@ -373,6 +407,7 @@ impl llm_base::Hyperparameters for Hyperparameters { n_vocab: util::read_i32(reader)?.try_into()?, n_embd: util::read_i32(reader)?.try_into()?, n_head: util::read_i32(reader)?.try_into()?, + n_head_kv: util::read_i32(reader)?.try_into()?, n_layer: util::read_i32(reader)?.try_into()?, file_type: util::read_filetype(reader)?, }; @@ -384,6 +419,7 @@ impl llm_base::Hyperparameters for Hyperparameters { util::write_i32(writer, self.n_vocab.try_into()?)?; util::write_i32(writer, self.n_embd.try_into()?)?; util::write_i32(writer, self.n_head.try_into()?)?; + util::write_i32(writer, self.n_head_kv.try_into()?)?; util::write_i32(writer, self.n_layer.try_into()?)?; util::write_i32(writer, self.file_type.into())?; Ok(()) @@ -404,8 +440,12 @@ impl llm_base::Hyperparameters for Hyperparameters { struct Layer { // normalization - attention_norm: Tensor, - attention_norm_b: Tensor, + input_layernorm: Tensor, + input_layernorm_b: Tensor, + + // Falcon-40B only + attention_norm: Option<Tensor>, + attention_norm_b: Option<Tensor>, // attention query_key_value: Tensor,