This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

add Falcon 40B model support #368

Merged: 13 commits, Jul 27, 2023
14 changes: 14 additions & 0 deletions crates/ggml/src/context.rs
@@ -473,6 +473,20 @@ impl Context {
let tensor = unsafe { sys::ggml_gelu(self.ptr.as_ptr(), a.ptr.as_ptr()) };
self.new_tensor_raw(tensor)
}

/// Flash attention: fused scaled dot-product attention over `q`, `k`, and `v`; `masked` applies a causal mask.
pub fn op_flash_attn(&self, q: &Tensor, k: &Tensor, v: &Tensor, masked: bool) -> Tensor {
let tensor = unsafe {
sys::ggml_flash_attn(
self.ptr.as_ptr(),
q.ptr.as_ptr(),
k.ptr.as_ptr(),
v.ptr.as_ptr(),
masked,
)
};
self.new_tensor_raw(tensor)
}
}

impl Drop for Context {
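For readers of the graph code below: `ggml_flash_attn` fuses the whole attention pipeline (QK^T, scaling, causal masking, softmax, and the multiply by V) into a single graph node. A rough sketch of the equivalence, reusing the tensor names from the Falcon graph later in this diff; the unfused path shown here is the code this PR deletes there:

// Unfused attention, as removed from the Falcon graph further down:
let kq = ctx0.op_mul_mat(&bigk, &bigq);
let kq_scaled = ctx0.op_scale_inplace(
    &kq,
    // n_embd / n_head is head_dim, so this is the usual 1/sqrt(head_dim)
    &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
);
let kq_masked = ctx0.op_diag_mask_inf_inplace(&kq_scaled, session_len);
let kq_softmax = ctx0.op_soft_max_inplace(&kq_masked);
let kqv = ctx0.op_mul_mat(&bigv, &kq_softmax);

// Fused equivalent; `masked = true` applies the causal mask internally:
let kqv = ctx0.op_flash_attn(&bigq, &bigk, &bigv, true);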
2 changes: 1 addition & 1 deletion crates/llm/Cargo.toml
@@ -33,7 +33,7 @@ default = ["models", "tokenizers-remote"]

tokenizers-remote = ["llm-base/tokenizers-remote"]

models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"]
models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "falcon"]
llama = ["dep:llm-llama"]
gpt2 = ["dep:llm-gpt2"]
gptj = ["dep:llm-gptj"]
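With `falcon` added to the default `models` set, the new architecture becomes loadable through the `llm` facade like any other. A minimal sketch, assuming the model is re-exported as `llm::models::Falcon` and following the loader pattern used for the existing architectures (the exact call is not part of this diff, and the path is a placeholder):

fn main() {
    // Sketch only: model path and tokenizer source are placeholders.
    let _model = llm::load::<llm::models::Falcon>(
        std::path::Path::new("falcon-40b-q4_0.ggml.bin"),
        llm::TokenizerSource::Embedded,
        llm::ModelParameters::default(),
        llm::load_progress_callback_stdout,
    )
    .unwrap_or_else(|err| panic!("failed to load model: {err}"));
}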
98 changes: 65 additions & 33 deletions crates/models/falcon/src/lib.rs
@@ -64,10 +64,29 @@ impl KnownModel for Falcon {
let lm_head = tl.load("lm_head.weight")?;

let mut layers = Vec::new();
// use n_head_kv to determine the model variant (Falcon-7B vs. Falcon-40B)
let Hyperparameters { n_head_kv, .. } = hyperparameters;
for i in 0..hyperparameters.n_layer {
let (input_layernorm_name, attention_norm_name) = if n_head_kv == 1 {
// falcon 7b
(format!("transformer.h.{i}.input_layernorm"), None)
} else {
// falcon 40b
(
format!("transformer.h.{i}.ln_mlp"),
Some(format!("transformer.h.{i}.ln_attn")),
)
};
let layer = Layer {
attention_norm: tl.load(&format!("transformer.h.{i}.input_layernorm.weight"))?,
attention_norm_b: tl.load(&format!("transformer.h.{i}.input_layernorm.bias"))?,
input_layernorm: tl.load(&format!("{}.weight", input_layernorm_name))?,
input_layernorm_b: tl.load(&format!("{}.bias", input_layernorm_name))?,
attention_norm: attention_norm_name
.as_ref()
.map(|path| tl.load(&format!("{}.weight", path)))
.transpose()?,
attention_norm_b: attention_norm_name
.map(|path| tl.load(&format!("{}.bias", path)))
.transpose()?,

query_key_value: tl.load(&format!(
"transformer.h.{i}.self_attention.query_key_value.weight"
@@ -123,6 +142,7 @@ impl KnownModel for Falcon {
let Hyperparameters {
n_embd,
n_head,
n_head_kv,
n_vocab,
n_layer,
..
@@ -163,18 +183,35 @@ impl KnownModel for Falcon {
current = ctx0.op_norm(&input_layer);
current = ctx0.op_add(
&ctx0.op_mul(
&ctx0.op_repeat(&self.layers[il].attention_norm, &current),
&ctx0.op_repeat(&self.layers[il].input_layernorm, &current),
&current,
),
&ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
&ctx0.op_repeat(&self.layers[il].input_layernorm_b, &current),
);

layernorm_output = current.share();

// Falcon-40B only
if n_head_kv != 1 {
current = ctx0.op_add(
&ctx0.op_mul(
&ctx0.op_repeat(
self.layers[il].attention_norm.as_ref().unwrap(),
&current,
),
&current,
),
&ctx0.op_repeat(
self.layers[il].attention_norm_b.as_ref().unwrap(),
&current,
),
);
}

// compute QKV
current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);

let fused_qkv_row_nb = (n_embd + 2 * (n_embd / n_head)) * f32_size;
let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size;

let mut qcur = ctx0.op_view_3d(
&current,
@@ -185,16 +222,16 @@ impl KnownModel for Falcon {

let mut kcur = ctx0.op_view_3d(
&current,
(head_dim, 1, n),
(head_dim, n_head_kv, n),
(head_dim * f32_size, fused_qkv_row_nb),
n_embd * f32_size,
head_dim * n_head * f32_size,
);

let vcur = ctx0.op_view_3d(
&current,
(head_dim, 1, n),
(head_dim, n_head_kv, n),
(head_dim * f32_size, fused_qkv_row_nb),
(n_embd + head_dim) * f32_size,
head_dim * (n_head + n_head_kv) * f32_size,
);

// using mode = 2 for neox mode
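To make the view arithmetic concrete: each fused QKV row stores every query head followed by one group of key heads and one group of value heads, which is where `head_dim * (n_head + 2 * n_head_kv)` comes from. A worked example with Falcon-40B's published shapes (n_head = 128, n_head_kv = 8, head_dim = 64; the numbers come from the released checkpoint, not from this diff):

fn main() {
    let (n_head, n_head_kv, head_dim) = (128usize, 8usize, 64usize);
    let f32_size = std::mem::size_of::<f32>(); // 4 bytes

    // Row stride of the fused QKV tensor.
    let fused_qkv_row_nb = head_dim * (n_head + 2 * n_head_kv) * f32_size;
    assert_eq!(fused_qkv_row_nb, 36_864);

    // Byte offsets of the three views into each row, matching qcur/kcur/vcur:
    let k_offset = head_dim * n_head * f32_size; // 32_768
    let v_offset = head_dim * (n_head + n_head_kv) * f32_size; // 34_816

    // The value section ends exactly on the row boundary.
    assert_eq!(v_offset + head_dim * n_head_kv * f32_size, fused_qkv_row_nb);
    println!("row = {fused_qkv_row_nb} B, K at {k_offset} B, V at {v_offset} B");
}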
@@ -205,13 +242,13 @@ impl KnownModel for Falcon {

let k = ctx0.op_view_1d(
memory_k,
n * head_dim,
(memory_k_size * head_dim) * (il * ctx_size + session_len),
n * n_head_kv * head_dim,
(memory_k_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
);
let v = ctx0.op_view_1d(
memory_v,
n * head_dim,
(memory_v_size * head_dim) * (il * ctx_size + session_len),
n * n_head_kv * head_dim,
(memory_v_size * n_head_kv * head_dim) * (il * ctx_size + session_len),
);

gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
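The KV cache is sized the same way: a cached token now occupies `n_head_kv * head_dim` elements per layer rather than `head_dim`, and each layer owns a region of `ctx_size` tokens. A small sketch of the offset math (only `n_head_kv` and `head_dim` are model values; the rest are illustrative):

fn main() {
    let (n_head_kv, head_dim) = (8usize, 64usize); // Falcon-40B
    let (ctx_size, session_len, il) = (2048usize, 100usize, 3usize); // illustrative
    let memory_k_size = 2; // bytes per cache element, e.g. an f16 cache

    let kv_row = n_head_kv * head_dim; // 512 elements per cached token

    // Start of layer `il`'s K cache, advanced past `session_len` tokens;
    // the same expression as the op_view_1d offset above.
    let offset = (memory_k_size * kv_row) * (il * ctx_size + session_len);
    println!("layer {il} K-cache write offset: {offset} bytes");
}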
@@ -224,28 +261,16 @@ impl KnownModel for Falcon {
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
memory_k,
(session_len + n) * head_dim,
il * ctx_size * memory_k_size * head_dim,
(session_len + n) * n_head_kv * head_dim,
il * ctx_size * memory_k_size * n_head_kv * head_dim,
),
head_dim,
1,
n_head_kv,
session_len + n,
),
(0, 2, 1, 3),
);
// K * Q
bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
let big_kq = ctx0.op_mul_mat(&bigk, &bigq);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
let big_kq_scaled = ctx0.op_scale_inplace(
&big_kq,
&ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
);

let big_kq_masked = ctx0.op_diag_mask_inf_inplace(&big_kq_scaled, session_len);

let big_kq_softmax = ctx0.op_soft_max_inplace(&big_kq_masked);

let mut bigv = ctx0.op_permute(
&ctx0.op_reshape_3d(
);
bigv = ctx0.op_cont(&ctx0.op_transpose(&ctx0.op_repeat(&bigv, &repeat_dummy)));

// KQV = transpose(V) * KQ_soft_max
let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax);
let big_kqv = ctx0.op_flash_attn(&bigq, &bigk, &bigv, true);
// KQV_merged = KQV.permute(0, 2, 1, 3)
let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3));

@@ -341,7 +365,7 @@ impl KnownModel for Falcon {
}

fn bot_token_id(&self) -> Option<TokenId> {
None
self.tokenizer.id(">>ABSTRACT<<".as_bytes())
}

fn eot_token_id(&self) -> TokenId {
@@ -366,6 +390,8 @@ pub struct Hyperparameters {
n_embd: usize,
/// Number of attention heads
n_head: usize,
/// Number of heads for key-value pairs
n_head_kv: usize,
/// Number of layers in the model
n_layer: usize,
/// file_type
@@ -378,6 +404,7 @@ impl llm_base::Hyperparameters for Hyperparameters {
n_vocab: util::read_i32(reader)?.try_into()?,
n_embd: util::read_i32(reader)?.try_into()?,
n_head: util::read_i32(reader)?.try_into()?,
n_head_kv: util::read_i32(reader)?.try_into()?,
n_layer: util::read_i32(reader)?.try_into()?,
file_type: util::read_filetype(reader)?,
};
@@ -389,6 +416,7 @@ impl llm_base::Hyperparameters for Hyperparameters {
util::write_i32(writer, self.n_vocab.try_into()?)?;
util::write_i32(writer, self.n_embd.try_into()?)?;
util::write_i32(writer, self.n_head.try_into()?)?;
util::write_i32(writer, self.n_head_kv.try_into()?)?;
util::write_i32(writer, self.n_layer.try_into()?)?;
util::write_i32(writer, self.file_type.into())?;
Ok(())
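Because the new field sits in the middle of the serialized block, a Falcon GGML file written before this change misaligns every field after `n_head` when read with the new code, so older conversions presumably need regenerating. For reference, the layout the read/write pair above implies:

fn main() {
    // Six little-endian i32 fields, in order:
    // n_vocab, n_embd, n_head, n_head_kv (new), n_layer, file_type.
    let header_bytes = 6 * std::mem::size_of::<i32>();
    assert_eq!(header_bytes, 24); // 20 bytes before n_head_kv was added
    println!("falcon hyperparameter block: {header_bytes} bytes");
}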
@@ -409,8 +437,12 @@ impl llm_base::Hyperparameters for Hyperparameters {

struct Layer {
// normalization
attention_norm: Tensor,
attention_norm_b: Tensor,
input_layernorm: Tensor,
input_layernorm_b: Tensor,

// Falcon-40B only
attention_norm: Option<Tensor>,
attention_norm_b: Option<Tensor>,

// attention
query_key_value: Tensor,