Commit 3c7c325

falcon : CPU inference working
ggerganov committed Aug 22, 2023
1 parent 085228e commit 3c7c325
Showing 1 changed file with 286 additions and 13 deletions.
299 changes: 286 additions & 13 deletions llama.cpp
@@ -1031,8 +1031,6 @@ struct llama_context {
// key + value cache for the self attention
struct llama_kv_cache kv_self;

size_t mem_per_token = 0;

// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;
@@ -2014,7 +2012,7 @@ static bool llama_model_load(
return true;
}

static struct ggml_cgraph * llama_build_graph(
static struct ggml_cgraph * llm_build_llama(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
@@ -2048,8 +2046,7 @@ static struct ggml_cgraph * llama_build_graph(

const int n_gpu_layers = model.n_gpu_layers;

auto & mem_per_token = lctx.mem_per_token;
auto & buf_compute = lctx.buf_compute;

struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
@@ -2340,20 +2337,296 @@ static struct ggml_cgraph * llama_build_graph(
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");

// logits -> probs
//cur = ggml_soft_max_inplace(ctx0, cur);

ggml_build_forward_expand(gf, cur);

if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
ggml_free(ctx0);

return gf;
}
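
// graph builder for Falcon models: fused QKV projection with a small number of
// KV heads (multi-query / grouped-query attention), NeoX-style RoPE, and
// parallel attention + FFN blocks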

static struct ggml_cgraph * llm_build_falcon(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past) {

GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT

const int N = n_tokens;

const auto & model = lctx.model;
const auto & hparams = model.hparams;

const auto & kv_self = lctx.kv_self;

GGML_ASSERT(!!kv_self.ctx);

const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
//const int64_t n_embd_gqa = hparams.n_embd_gqa();

GGML_ASSERT(n_embd_head == hparams.n_rot);

auto & buf_compute = lctx.buf_compute;

struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ false,
};

params.no_alloc = true;

struct ggml_context * ctx0 = ggml_init(params);

ggml_cgraph * gf = ggml_new_graph(ctx0);

struct ggml_tensor * cur;
struct ggml_tensor * inpL;
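// graph input: either raw token ids (embedded via tok_embeddings) or
// externally provided float embeddings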

if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");

inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif

inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);

ggml_allocr_alloc(lctx.alloc, inpL);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
}
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
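// KQ_scale holds 1/sqrt(n_embd_head), the usual attention scaling factor
// (note: the attention scaling below constructs an equivalent scalar inline)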

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
struct ggml_tensor * layernorm_output;

// self-attention
{
layernorm_output = ggml_norm(ctx0, inpL);

layernorm_output = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attn_norm, layernorm_output),
layernorm_output),
ggml_repeat(ctx0, model.layers[il].attn_norm_b, layernorm_output));
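
// Falcon-40B (detected here via n_head_kv == 8) applies a second layer norm to the
// attention input, while layernorm_output feeds the parallel FFN branch below;
// Falcon-7B reuses layernorm_output for both branches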

if ( hparams.n_head_kv == 8 ) { // Falcon-40B
cur = ggml_norm(ctx0, inpL);

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur),
cur),
ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur));
}
else { // Falcon 7B
cur = layernorm_output;
}

// compute QKV

cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);

// Note that the strides for Kcur, Vcur are set up so that the
// resulting views are misaligned with the tensor's storage
// (by applying the K/V offset we shift the tensor's original
// view to stick out behind the viewed QKV tensor's allocated
// memory, so to say). This is ok because no actual accesses
// happen to that out-of-range memory, but it can require some
// trickery when trying to accurately dump these views for
// debugging.
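// the fused QKV output is split into views: the first n_head heads are Q,
// followed by n_head_kv K heads and n_head_kv V heads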

struct ggml_tensor * Qcur = ggml_view_3d(
ctx0, cur, n_embd_head, n_head, N,
n_embd_head * ggml_type_size(GGML_TYPE_F32),
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
0);

struct ggml_tensor * Kcur = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
n_embd_head * ggml_type_size(GGML_TYPE_F32),
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
n_embd_head * n_head * ggml_type_size(GGML_TYPE_F32));

struct ggml_tensor * Vcur = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N,
n_embd_head * ggml_type_size(GGML_TYPE_F32),
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
n_embd_head * (n_head + n_head_kv) * ggml_type_size(GGML_TYPE_F32));

// using mode = 2 for neox mode
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0);
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0);
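// RoPE is applied to Q and K only; V is written to the cache unrotated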

// store key and value to memory
{
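// the K and V caches store, per layer, n_ctx slots of n_head_kv*n_embd_head values;
// the N new entries are written starting at position n_past of the current layer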
struct ggml_tensor* k = ggml_view_1d(
ctx0, kv_self.k, N * n_head_kv * n_embd_head,
(ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) *
(il * n_ctx + n_past));
struct ggml_tensor* v = ggml_view_1d(
ctx0, kv_self.v, N * n_head_kv * n_embd_head,
(ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) *
(il * n_ctx + n_past));

ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
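
// gather the cached keys for this layer (n_past + N positions) and arrange them
// as [n_embd_head, n_past + N, n_head_kv] for the K*Q matmul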

struct ggml_tensor * K = ggml_permute(
ctx0,
ggml_reshape_3d(
ctx0,
ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head,
il * n_ctx *
ggml_element_size(kv_self.k) *
n_head_kv *
n_embd_head),
n_embd_head, n_head_kv, n_past + N),
0, 2, 1, 3);

// K * Q

// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));

struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ggml_tensor * KQ_scaled =
ggml_scale_inplace(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd_head)))
);

// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);

// gather the cached values for this layer, grouped into n_head_kv heads, and transpose them for the KQV matmul
struct ggml_tensor* V = ggml_permute(
ctx0,
ggml_reshape_3d(
ctx0,
ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head,
il * n_ctx *
ggml_element_size(kv_self.v) *
n_head_kv *
n_embd_head),
n_embd_head, n_head_kv, n_past + N),
0, 2, 1, 3);

// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
V = ggml_cont(ctx0, ggml_transpose(ctx0, V));

// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

// projection
{
cur = ggml_mul_mat(ctx0,
model.layers[il].wo,
cur);
}
}
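
// Falcon evaluates attention and the FFN in parallel: the FFN below reads the same
// layer-normalized input, and both results are added to the residual (inpL)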

struct ggml_tensor* inpFF = layernorm_output;
struct ggml_tensor* attn_out = ggml_cpy(
ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

{
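// Falcon FFN: up-projection (w3), GELU, down-projection (w2); unlike LLaMA there is no gating branch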
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
}

cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_add(ctx0, cur, inpL);
// input for next layer
inpL = cur;
}

// norm
{
cur = ggml_norm(ctx0, inpL);

cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.output_norm, cur),
cur),
ggml_repeat(ctx0, model.output_norm_b, cur));
ggml_set_name(cur, "result_norm");
}
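
// project the final hidden state to vocabulary logits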

cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");

ggml_build_forward_expand(gf, cur);

ggml_free(ctx0);

return gf;
}

static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past) {
const auto & model = lctx.model;

struct ggml_cgraph * result = NULL;
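
// build the graph with the builder that matches the model architecture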

switch (model.arch) {
case LLM_ARCH_LLAMA:
{
result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
} break;
case LLM_ARCH_FALCON:
{
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
} break;
default:
GGML_ASSERT(false);
};

return result;
}

// evaluate the transformer
//
// - lctx: llama context
@@ -2427,11 +2700,11 @@ static bool llama_eval_internal(
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];

GGML_ASSERT(strcmp(res->name, "result_output") == 0);
GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);

#if GGML_USE_MPI
const int64_t n_layer = hparams.n_layer;
