From 3c7c325b9867a30637529b0328fbc73b4e527004 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 22 Aug 2023 22:31:39 +0300 Subject: [PATCH] falcon : CPU inference working --- llama.cpp | 299 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 286 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index e19f46a88817f..c942b0727e407 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1031,8 +1031,6 @@ struct llama_context { // key + value cache for the self attention struct llama_kv_cache kv_self; - size_t mem_per_token = 0; - // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; bool logits_all = false; @@ -2014,7 +2012,7 @@ static bool llama_model_load( return true; } -static struct ggml_cgraph * llama_build_graph( +static struct ggml_cgraph * llm_build_llama( llama_context & lctx, const llama_token * tokens, const float * embd, @@ -2048,8 +2046,7 @@ static struct ggml_cgraph * llama_build_graph( const int n_gpu_layers = model.n_gpu_layers; - auto & mem_per_token = lctx.mem_per_token; - auto & buf_compute = lctx.buf_compute; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, @@ -2340,20 +2337,296 @@ static struct ggml_cgraph * llama_build_graph( cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); - // logits -> probs - //cur = ggml_soft_max_inplace(ctx0, cur); - ggml_build_forward_expand(gf, cur); - if (mem_per_token == 0) { - mem_per_token = ggml_used_mem(ctx0)/N; + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_falcon( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + //const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } + } + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur; + struct ggml_tensor * layernorm_output; + + // self-attention + { + layernorm_output = ggml_norm(ctx0, inpL); + + layernorm_output = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].attn_norm, layernorm_output), + layernorm_output), + ggml_repeat(ctx0, model.layers[il].attn_norm_b, layernorm_output)); + + if ( hparams.n_head_kv == 8 ) { // Falcon-40B + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur), + cur), + ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur)); + } + else { // Falcon 7B + cur = layernorm_output; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + + // Note that the strides for Kcur, Vcur are set up so that the + // resulting views are misaligned with the tensor's storage + // (by applying the K/V offset we shift the tensor's original + // view to stick out behind the viewed QKV tensor's allocated + // memory, so to say). This is ok because no actual accesses + // happen to that out-of-range memory, but it can require some + // trickery when trying to accurately dump these views for + // debugging. + + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + 0); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + n_embd_head * n_head * ggml_type_size(GGML_TYPE_F32)); + + struct ggml_tensor * Vcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + n_embd_head * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32), + n_embd_head * (n_head + n_head_kv) * ggml_type_size(GGML_TYPE_F32)); + + // using mode = 2 for neox mode + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0); + + // store key and value to memory + { + struct ggml_tensor* k = ggml_view_1d( + ctx0, kv_self.k, N * n_head_kv * n_embd_head, + (ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) * + (il * n_ctx + n_past)); + struct ggml_tensor* v = ggml_view_1d( + ctx0, kv_self.v, N * n_head_kv * n_embd_head, + (ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) * + (il * n_ctx + n_past)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * K = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head, + il * n_ctx * + ggml_element_size(kv_self.k) * + n_head_kv * + n_embd_head), + n_embd_head, n_head_kv, n_past + N), + 0, 2, 1, 3); + + // K * Q + +// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy)); + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd_head))) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor* V = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head, + il * n_ctx * + ggml_element_size(kv_self.v) * + n_head_kv * + n_embd_head), + n_embd_head, n_head_kv, n_past + N), + 0, 2, 1, 3); + +// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy))); + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + } + } + + struct ggml_tensor* inpFF = layernorm_output; + struct ggml_tensor* attn_out = ggml_cpy( + ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + { + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + cur = ggml_add(ctx0, cur, inpL); + // input for next layer + inpL = cur; + } + + // norm + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.output_norm, cur), + cur), + ggml_repeat(ctx0, model.output_norm_b, cur)); + ggml_set_name(cur, "result_norm"); } + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); return gf; } +static struct ggml_cgraph * llama_build_graph( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + const auto & model = lctx.model; + + struct ggml_cgraph * result = NULL; + + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); + } break; + case LLM_ARCH_FALCON: + { + result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); + } break; + default: + GGML_ASSERT(false); + }; + + return result; +} + // evaluate the transformer // // - lctx: llama context @@ -2427,11 +2700,11 @@ static bool llama_eval_internal( // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); #if GGML_USE_MPI const int64_t n_layer = hparams.n_layer;