// Patch summary: reset the Mistral attention key/value cache once a chat completion has been assembled.
//
// src/openai/models/mistral.rs
//   * hunk in `impl Attention`: adds a private `clear_cache(&mut self)` that sets `self.kv_cache` to `None`.
//   * hunk in `impl Model`: adds a public `clear_cache(&mut self)` that walks `self.layers` mutably and calls
//     `self_attn.clear_cache()` on each entry.
//
// src/openai/pipelines/mistral.rs
//   * hunk in `impl<'s> ModulePipeline<'s> for Mistral7BPipeline`: calls `self.mistral.clear_cache()` immediately
//     before the `Ok((Some(choices), ChatCompletionUsageResponse { ...` result is built.
//
// NOTE(review): the reset is only visible on the path that reaches the final `Ok(...)`. If earlier code in that
// method can return early (for example through `?`), the cache would presumably stay populated for the next
// call -- confirm against the full method body, which this patch does not show.
// NOTE(review): the hunk headers declare 6->10, 4->10 and 6->8 lines, but the line breaks of the patch body appear
// to have been collapsed into spaces; verify the patch still applies before relying on it.
diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index bdccc24..fa82b3f 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -278,6 +278,10 @@ impl Attention { .reshape((b_sz, q_len, self.hidden_size))? .apply(&self.o_proj) } + + fn clear_cache(&mut self) { + self.kv_cache = None; + } } #[derive(Debug, Clone)] @@ -404,4 +408,10 @@ impl Model { .apply(&self.norm)? .apply(&self.lm_head) } + + pub fn clear_cache(&mut self) { + for block in &mut self.layers { + block.self_attn.clear_cache(); + } + } } diff --git a/src/openai/pipelines/mistral.rs b/src/openai/pipelines/mistral.rs index 4cf5146..cd0c4ec 100644 --- a/src/openai/pipelines/mistral.rs +++ b/src/openai/pipelines/mistral.rs @@ -478,6 +478,8 @@ impl<'s> ModulePipeline<'s> for Mistral7BPipeline { } } + self.mistral.clear_cache(); + Ok(( Some(choices), ChatCompletionUsageResponse {