Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
remove bos token id and use float16 kv memory type
Browse files (browse the repository at this point in the history)
  • Loading branch information
skirodev committed Jul 26, 2023
1 parent 871a5d8 commit 0afd18e
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions crates/models/falcon/src/lib.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -150,12 +150,6 @@ impl KnownModel for Falcon {
let ctx0 = builder.ctx0.borrow();
let embd = builder.embd;
let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd);
- let repeat_dummy = ctx0.new_tensor_3d(
- input_layer.get_type(),
- head_dim,
- input_len + session_len,
- n_head,
- );

let f32_size = std::mem::size_of::<f32>();

Expand Down Expand Up @@ -254,7 +248,7 @@ impl KnownModel for Falcon {
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
let bigq = ctx0.op_permute(&qcur, (0, 2, 1, 3));

- let mut bigk = ctx0.op_permute(
+ let bigk = ctx0.op_permute(
&ctx0.op_reshape_3d(
&ctx0.op_view_1d(
memory_k,
Expand All @@ -267,7 +261,19 @@ impl KnownModel for Falcon {
),
(0, 2, 1, 3),
);
- bigk = ctx0.op_cont(&ctx0.op_repeat(&bigk, &repeat_dummy));
+
+ // K * Q
+ let big_kq = ctx0.op_mul_mat(&bigk, &bigq);
+
+ // KQ_scaled = KQ / sqrt(n_embd/n_head)
+ let big_kq_scaled = ctx0.op_scale_inplace(
+ &big_kq,
+ &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
+ );
+
+ let big_kq_masked = ctx0.op_diag_mask_inf_inplace(&big_kq_scaled, session_len);
+
+ let big_kq_softmax = ctx0.op_soft_max_inplace(&big_kq_masked);

let mut bigv = ctx0.op_permute(
&ctx0.op_reshape_3d(
Expand All @@ -282,9 +288,9 @@ impl KnownModel for Falcon {
),
(0, 2, 1, 3),
);
- bigv = ctx0.op_cont(&ctx0.op_transpose(&ctx0.op_repeat(&bigv, &repeat_dummy)));
+ bigv = ctx0.op_cont(&ctx0.op_transpose(&bigv));

- let big_kqv = ctx0.op_flash_attn(&bigq, &bigk, &bigv, true);
+ let big_kqv = ctx0.op_mul_mat(&bigv, &big_kq_softmax);
// KQV_merged = KQV.permute(0, 2, 1, 3)
let big_kqv_merged = ctx0.op_permute(&big_kqv, (0, 2, 1, 3));

Expand Down Expand Up @@ -362,7 +368,7 @@ impl KnownModel for Falcon {
}

fn bot_token_id(&self) -> Option<TokenId> {
- self.tokenizer.id(">>ABSTRACT<<".as_bytes())
+ None
}

fn eot_token_id(&self) -> TokenId {
Expand Down

0 comments on commit 0afd18e

Please sign in to comment.