diff --git a/.github/workflows/cargo_build.yml b/.github/workflows/cargo_build.yml
index bc07d69..1cb2603 100644
--- a/.github/workflows/cargo_build.yml
+++ b/.github/workflows/cargo_build.yml
@@ -34,3 +34,4 @@ jobs:
         run: mkdir /tmp/models && cargo test --verbose
         env:
           LLAMA_CPP_TEST_MODELS: "/tmp/models"
+          LLAMA_EMBED_MODELS_DIR: "/tmp/models"
diff --git a/crates/llama_cpp/src/model/mod.rs b/crates/llama_cpp/src/model/mod.rs
index d775ca1..4af18f4 100644
--- a/crates/llama_cpp/src/model/mod.rs
+++ b/crates/llama_cpp/src/model/mod.rs
@@ -12,7 +12,7 @@ use futures::executor::block_on;
 use thiserror::Error;
 use tokio::sync::Mutex;
 use tokio::sync::RwLock;
-use tracing::info;
+use tracing::{info, trace, warn};
 
 use backend::BackendRef;
 use llama_cpp_sys::{
@@ -302,13 +302,7 @@ impl LlamaModel {
             token.0
         );
 
-        unsafe {
-            CStr::from_ptr(llama_token_get_text(
-                **self.model,
-                token.0,
-            ))
-        }
-        .to_bytes()
+        unsafe { CStr::from_ptr(llama_token_get_text(**self.model, token.0)) }.to_bytes()
     }
 
     /// Converts the provided token into a `Vec<u8>` piece, using the model's vocabulary.
@@ -459,28 +453,36 @@ impl LlamaModel {
                 let ptr = llama_get_embeddings_ith(context, i as i32);
                 slice_from_raw_parts(ptr, self.embedding_length)
                     .as_ref()
-                    .ok_or(LlamaContextError::DecodeFailed(1))?
+                    .ok_or(LlamaContextError::EmbeddingsFailed(
+                        "Could not parse embeddings".to_string(),
+                    ))?
             };
 
-            // normalize the embedding
-            let mut embed_vec = vec![0f32; self.embedding_length];
-            let sum = embedding
-                .iter()
-                .map(move |x| x * x)
-                .reduce(move |a, b| a + b)
-                .ok_or(LlamaContextError::DecodeFailed(2))?;
-
-            let norm = sum.sqrt();
-            for (i, value) in embedding.iter().enumerate() {
-                embed_vec[i] = value / norm;
-            }
-
-            out.push(embed_vec)
+            out.push(self.normalise_embedding(embedding)?)
         }
 
         Ok(out)
     }
 
+    /// Normalise an embeddings vector.
+    fn normalise_embedding(&self, embedding: &[f32]) -> Result<Vec<f32>, LlamaContextError> {
+        let mut embed_vec = vec![0f32; self.embedding_length];
+        let sum = embedding
+            .iter()
+            .map(move |x| x * x)
+            .reduce(move |a, b| a + b)
+            .ok_or(LlamaContextError::EmbeddingsFailed(
+                "Could not normalise vector".to_string(),
+            ))?;
+
+        let norm = sum.sqrt();
+        for (i, value) in embedding.iter().enumerate() {
+            embed_vec[i] = value / norm;
+        }
+
+        Ok(embed_vec)
+    }
+
     /// Runs embeddings inference for the given inputs vector, returning the result.
     fn embeddings_process(
         &self,
@@ -496,7 +498,12 @@ impl LlamaModel {
             }
         }
 
-        let batch_capacity = min(self.training_size, total_tokens);
+        let batch_capacity = if max_tokens > self.training_size {
+            warn!("Large embedding input requires a context larger than the model's training context.");
+            max_tokens
+        } else {
+            min(self.training_size, total_tokens)
+        };
         let mut batch = Batch::new(batch_capacity, 0, inputs.len());
         let mut out = Vec::with_capacity(inputs.len());
 
@@ -506,6 +513,8 @@ impl LlamaModel {
         ctx_params.embedding = true;
         ctx_params.n_threads = params.n_threads;
         ctx_params.n_threads_batch = params.n_threads_batch;
+        ctx_params.n_ctx = batch_capacity as u32;
+        ctx_params.n_batch = batch_capacity as u32;
         // SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
         // for at least the lifetime of `LlamaContext`.
         llama_new_context_with_model(**self.model, ctx_params)
@@ -518,11 +527,13 @@ impl LlamaModel {
         let mut batch_input_count = 0;
         for input in inputs {
             if batch.tokens() + input.len() > batch_capacity {
+                trace!("Decoding {} embedding tokens", batch.tokens());
                 out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?);
                 batch.clear();
                 batch_input_count = 0;
             }
 
+            trace!("Adding {} tokens to batch", input.len());
             for (i, token) in input.iter().enumerate() {
                 batch.add(*token, i, &[batch_input_count as i32], false);
             }
@@ -530,6 +541,7 @@ impl LlamaModel {
         }
 
         if 0 < batch_input_count {
+            trace!("Decoding remaining {} embedding tokens", batch.tokens());
             out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?);
         }
 
diff --git a/crates/llama_cpp/src/session/mod.rs b/crates/llama_cpp/src/session/mod.rs
index ec91504..8e75ba4 100644
--- a/crates/llama_cpp/src/session/mod.rs
+++ b/crates/llama_cpp/src/session/mod.rs
@@ -110,6 +110,10 @@ pub enum LlamaContextError {
     /// An error occurred on the other side of the FFI boundary; check your logs.
     #[error("advancing context failed (error code {0})")]
     DecodeFailed(i32),
+
+    /// An error occurred on the other side of the FFI boundary; check your logs.
+    #[error("failed to process embeddings (reason: {0})")]
+    EmbeddingsFailed(String),
 }
 
 impl LlamaSession {
diff --git a/crates/llama_cpp_tests/src/lib.rs b/crates/llama_cpp_tests/src/lib.rs
index 0ca34c7..79fc1e3 100644
--- a/crates/llama_cpp_tests/src/lib.rs
+++ b/crates/llama_cpp_tests/src/lib.rs
@@ -7,6 +7,7 @@ mod tests {
 
     use std::io;
     use std::io::Write;
+    use std::path::Path;
     use std::time::Duration;
 
     use futures::StreamExt;
@@ -14,19 +15,12 @@ mod tests {
     use tokio::time::Instant;
 
     use llama_cpp::standard_sampler::StandardSampler;
-    use llama_cpp::{CompletionHandle, LlamaModel, LlamaParams, SessionParams, TokensToStrings};
+    use llama_cpp::{
+        CompletionHandle, EmbeddingsParams, LlamaModel, LlamaParams, SessionParams, TokensToStrings,
+    };
 
-    async fn list_models() -> Vec<String> {
-        let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
-            eprintln!(
-                "LLAMA_CPP_TEST_MODELS environment variable not set. \
-                Please set this to the directory containing one or more GGUF models."
-            );
-
-            std::process::exit(1)
-        });
-
-        let dir = std::path::Path::new(&dir);
+    async fn list_models(dir: impl AsRef<Path>) -> Vec<String> {
+        let dir = dir.as_ref();
 
         if !dir.is_dir() {
             panic!("\"{}\" is not a directory", dir.to_string_lossy());
@@ -53,7 +47,14 @@ mod tests {
     #[ignore]
     #[tokio::test]
    async fn load_models() {
-        let models = list_models().await;
+        let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
+            panic!(
+                "LLAMA_CPP_TEST_MODELS environment variable not set. \
+                Please set this to the directory containing one or more GGUF models."
+            );
+        });
+
+        let models = list_models(dir).await;
 
         for model in models {
             println!("Loading model: {}", model);
@@ -65,7 +66,14 @@ mod tests {
 
     #[tokio::test]
     async fn execute_completions() {
-        let models = list_models().await;
+        let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
+            panic!(
+                "LLAMA_CPP_TEST_MODELS environment variable not set. \
+                Please set this to the directory containing one or more GGUF models."
+            );
+        });
+
+        let models = list_models(dir).await;
 
         for model in models {
             let mut params = LlamaParams::default();
@@ -124,4 +132,47 @@ mod tests {
             println!();
         }
     }
+
+    #[tokio::test]
+    async fn embed() {
+        let dir = std::env::var("LLAMA_EMBED_MODELS_DIR").unwrap_or_else(|_| {
+            panic!(
+                "LLAMA_EMBED_MODELS_DIR environment variable not set. \
+                Please set this to the directory containing one or more embedding GGUF models."
+            );
+        });
+
+        let models = list_models(dir).await;
+
+        for model in models {
+            let params = LlamaParams::default();
+            let model = LlamaModel::load_from_file_async(model, params)
+                .await
+                .expect("Failed to load model");
+
+            let mut input = vec![];
+
+            for _phrase_idx in 0..2 {
+                let mut phrase = String::new();
+                for _word_idx in 0..3000 {
+                    phrase.push_str("word ");
+                }
+                phrase.truncate(phrase.len() - 1);
+                input.push(phrase);
+            }
+
+            let params = EmbeddingsParams::default();
+            let res = model
+                .embeddings_async(&input, params)
+                .await
+                .expect("Failed to infer embeddings");
+
+            for embedding in &res {
+                assert!(embedding[0].is_normal(), "Embedding value isn't normal");
+                assert!(embedding[0] >= 0f32, "Embedding value isn't normalised");
+                assert!(embedding[0] <= 1f32, "Embedding value isn't normalised");
+            }
+            println!("{:?}", res[0]);
+        }
+    }
 }
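A minimal usage sketch of the embeddings path exercised by the `embed` test above. It mirrors the test's calls (`load_from_file_async`, `embeddings_async`, `EmbeddingsParams::default()`); the model path and input strings are placeholders, and a tokio runtime with the `macros` feature is assumed.

    use llama_cpp::{EmbeddingsParams, LlamaModel, LlamaParams};

    #[tokio::main]
    async fn main() {
        // Placeholder path; point this at any GGUF embedding model on disk.
        let model = LlamaModel::load_from_file_async("/tmp/models/embed.gguf", LlamaParams::default())
            .await
            .expect("Failed to load model");

        // One embedding vector is produced per input string.
        let inputs = vec!["first phrase".to_string(), "second phrase".to_string()];
        let embeddings = model
            .embeddings_async(&inputs, EmbeddingsParams::default())
            .await
            .expect("Failed to infer embeddings");

        // Vectors are L2-normalised by `normalise_embedding` before being returned.
        for (i, embedding) in embeddings.iter().enumerate() {
            println!("input {}: {} dims, first value {}", i, embedding.len(), embedding[0]);
        }
    }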