Skip to content

Commit

Permalink
llama: clear kv cache for each run_model call
Browse files Browse the repository at this point in the history
This commit adds a call to `llama_kv_cache_clear` for each call to
`run_model`. This is done because the same sequence id is currently
being used for each call to `run_model` which can cause tokens from a
previous call to be in the catch. This can cause the model to use tokens
from a previous decode call in the attention mechanism which can cause
the model to generate incorrect information.

Signed-off-by: Daniel Bevenius <[email protected]>
  • Loading branch information
danbev committed Dec 15, 2023
1 parent cbc1fdc commit 5333d50
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
6 changes: 5 additions & 1 deletion crates/llm-chain-llama/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::options::LlamaInvocation;
use anyhow::Result;
use llm_chain_llama_sys::{
llama_context, llama_context_default_params, llama_context_params, llama_decode, llama_eval,
llama_free, llama_get_embeddings, llama_get_logits, llama_get_logits_ith,
llama_free, llama_get_embeddings, llama_get_logits, llama_get_logits_ith, llama_kv_cache_clear,
llama_load_model_from_file, llama_model, llama_n_embd, llama_n_vocab,
llama_new_context_with_model, llama_sample_repetition_penalties, llama_sample_tail_free,
llama_sample_temperature, llama_sample_token, llama_sample_token_greedy,
Expand Down Expand Up @@ -310,6 +310,10 @@ impl LLamaContext {
unsafe { llama_token_nl(self.model) }
}

pub fn llama_kv_cache_clear(&self) {
unsafe { llama_kv_cache_clear(self.ctx) }
}

pub fn llama_token_to_piece(
&self,
token_id: i32,
Expand Down
11 changes: 11 additions & 0 deletions crates/llm-chain-llama/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ impl Executor {
let context_size = context_size;
let context = context.blocking_lock();

// The following clears the Key-Value cache to allow conversational
// (chat) applications to be able to call run_model multiple times
// using the same context. Without this, and because the same
// sequence id is used below, the cache can contain tokens from
// a previous interaction which may cause the model to generate
// a response that is not appropriate for the current prompt.
//
// TODO(danbev) Is there a better way to do this, perhaps by using
// sequence ids in some way?
context.llama_kv_cache_clear();

let tokenized_stop_prompt = tokenize(
&context,
input
Expand Down

0 comments on commit 5333d50

Please sign in to comment.