From 2e30a2b0356c1f3d589e670523fbca0b342e1438 Mon Sep 17 00:00:00 2001 From: Jan Ploski Date: Mon, 26 Jun 2023 01:32:46 +0200 Subject: [PATCH] Eliminated need for ggml_repeat2 by using a modified version of https://github.com/ggerganov/ggml/pull/224 instead --- examples/falcon/main.cpp | 5 +---- src/ggml.c | 6 +++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/falcon/main.cpp b/examples/falcon/main.cpp index faad654b1..db6ded891 100644 --- a/examples/falcon/main.cpp +++ b/examples/falcon/main.cpp @@ -441,7 +441,6 @@ bool falcon_eval( // wte struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); - struct ggml_tensor* repeat_dummy = ggml_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head); ggml_type wtype = GGML_TYPE_F32; const int sizeof_wtype = ggml_type_sizef(wtype); @@ -539,8 +538,6 @@ bool falcon_eval( // K * Q - K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy)); - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); @@ -570,7 +567,7 @@ bool falcon_eval( head_dim, n_head_kv, n_past + N), 0, 2, 1, 3); - V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy))); + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); // KQV = transpose(V) * KQ_soft_max struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); diff --git a/src/ggml.c b/src/ggml.c index 3c16f2102..0d5ceb749 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -10735,7 +10735,11 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ir0 = (ir1/ne11)%(ne02*ne03); const int64_t i03 = (ir0/(ne02)); - const int64_t i02 = (ir0 - i03*ne02); + // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2. + // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470: + const int64_t i02 = (i12 / (ne12 / ne02)); + // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon) + // const int64_t i02 = (ir0 - i03*ne02); const int64_t i1 = i11; const int64_t i2 = i12;