Merge pull request #52 from edgenai/fix/embeddings
handle large embedding inputs
pedro-devv authored Mar 1, 2024
2 parents 1c02450 + f8358eb commit a90d8d2
Showing 4 changed files with 106 additions and 38 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cargo_build.yml
@@ -34,3 +34,4 @@ jobs:
run: mkdir /tmp/models && cargo test --verbose
env:
LLAMA_CPP_TEST_MODELS: "/tmp/models"
LLAMA_EMBED_MODELS_DIR: "/tmp/models"
60 changes: 36 additions & 24 deletions crates/llama_cpp/src/model/mod.rs
@@ -12,7 +12,7 @@ use futures::executor::block_on;
use thiserror::Error;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tracing::info;
use tracing::{info, trace, warn};

use backend::BackendRef;
use llama_cpp_sys::{
@@ -302,13 +302,7 @@ impl LlamaModel {
token.0
);

unsafe {
CStr::from_ptr(llama_token_get_text(
**self.model,
token.0,
))
}
.to_bytes()
unsafe { CStr::from_ptr(llama_token_get_text(**self.model, token.0)) }.to_bytes()
}

/// Converts the provided token into a `Vec<u8>` piece, using the model's vocabulary.
@@ -459,28 +453,36 @@ impl LlamaModel {
let ptr = llama_get_embeddings_ith(context, i as i32);
slice_from_raw_parts(ptr, self.embedding_length)
.as_ref()
.ok_or(LlamaContextError::DecodeFailed(1))?
.ok_or(LlamaContextError::EmbeddingsFailed(
"Could not parse embeddings".to_string(),
))?
};

// normalize the embedding
let mut embed_vec = vec![0f32; self.embedding_length];
let sum = embedding
.iter()
.map(move |x| x * x)
.reduce(move |a, b| a + b)
.ok_or(LlamaContextError::DecodeFailed(2))?;

let norm = sum.sqrt();
for (i, value) in embedding.iter().enumerate() {
embed_vec[i] = value / norm;
}

out.push(embed_vec)
out.push(self.normalise_embedding(embedding)?)
}

Ok(out)
}

/// Normalise an embeddings vector.
fn normalise_embedding(&self, embedding: &[f32]) -> Result<Vec<f32>, LlamaContextError> {
let mut embed_vec = vec![0f32; self.embedding_length];
let sum = embedding
.iter()
.map(move |x| x * x)
.reduce(move |a, b| a + b)
.ok_or(LlamaContextError::EmbeddingsFailed(
"Could not normalise vector".to_string(),
))?;

let norm = sum.sqrt();
for (i, value) in embedding.iter().enumerate() {
embed_vec[i] = value / norm;
}

Ok(embed_vec)
}
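
For reference, normalise_embedding performs a plain L2 (Euclidean) normalisation: each component is divided by the vector's magnitude, so the result has unit length. A minimal standalone sketch of the same computation (the free function below is illustrative, not part of the crate's API):

fn l2_normalise(input: &[f32]) -> Option<Vec<f32>> {
    // Sum of squares; a zero-length or all-zero input has no meaningful norm.
    let sum: f32 = input.iter().map(|x| x * x).sum();
    if sum == 0.0 {
        return None;
    }
    let norm = sum.sqrt();
    Some(input.iter().map(|x| x / norm).collect())
}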

/// Runs embeddings inference for the given inputs vector, returning the result.
fn embeddings_process(
&self,
@@ -496,7 +498,12 @@ impl LlamaModel {
}
}

let batch_capacity = min(self.training_size, total_tokens);
let batch_capacity = if max_tokens > self.training_size {
warn!("Large embedding input requires a context larger than the model's training context.");
max_tokens
} else {
min(self.training_size, total_tokens)
};
let mut batch = Batch::new(batch_capacity, 0, inputs.len());
let mut out = Vec::with_capacity(inputs.len());
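
The capacity rule above can be read on its own: the batch (and context) is normally sized to the smaller of the model's training context and the total token count, but if a single input alone exceeds the training context, the context is grown to fit it, at the cost of running the model beyond its trained length (hence the warning). A hedged sketch of the same decision, with illustrative names:

use std::cmp::min;

/// `training_size`: the model's training context length;
/// `max_tokens`: the longest single input; `total_tokens`: the sum over all inputs.
fn pick_batch_capacity(training_size: usize, max_tokens: usize, total_tokens: usize) -> usize {
    if max_tokens > training_size {
        // One input alone overflows the training context: grow to fit it.
        max_tokens
    } else {
        min(training_size, total_tokens)
    }
}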

@@ -506,6 +513,8 @@
ctx_params.embedding = true;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch;
ctx_params.n_ctx = batch_capacity as u32;
ctx_params.n_batch = batch_capacity as u32;
// SAFETY: due to `_model` being declared in the `LlamaContext`, `self` must live
// for at least the lifetime of `LlamaContext`.
llama_new_context_with_model(**self.model, ctx_params)
@@ -518,18 +527,21 @@
let mut batch_input_count = 0;
for input in inputs {
if batch.tokens() + input.len() > batch_capacity {
trace!("Decoding {} embedding tokens", batch.tokens());
out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?);
batch.clear();
batch_input_count = 0;
}

trace!("Adding {} tokens to batch", input.len());
for (i, token) in input.iter().enumerate() {
batch.add(*token, i, &[batch_input_count as i32], false);
}
batch_input_count += 1;
}

if 0 < batch_input_count {
trace!("Decoding remaining {} embedding tokens", batch.tokens());
out.append(&mut self.embeddings_decode(context, &batch, batch_input_count)?);
}

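The batching loop above follows a common chunking pattern: inputs are packed into the batch until the next one would overflow its capacity, the batch is then decoded and cleared, and a final decode flushes whatever remains. A simplified sketch of the control flow, with the decode step and types as stand-ins for the crate's internals:

fn process_in_chunks<T>(inputs: &[Vec<T>], capacity: usize, mut decode: impl FnMut(&[&[T]])) {
    let mut pending: Vec<&[T]> = Vec::new();
    let mut used = 0;
    for input in inputs {
        if used + input.len() > capacity && !pending.is_empty() {
            decode(&pending); // flush the full batch
            pending.clear();
            used = 0;
        }
        pending.push(input);
        used += input.len();
    }
    if !pending.is_empty() {
        decode(&pending); // flush the remainder
    }
}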
4 changes: 4 additions & 0 deletions crates/llama_cpp/src/session/mod.rs
@@ -110,6 +110,10 @@ pub enum LlamaContextError {
/// An error occurred on the other side of the FFI boundary; check your logs.
#[error("advancing context failed (error code {0})")]
DecodeFailed(i32),

/// An error occurred on the other side of the FFI boundary; check your logs.
#[error("failed to process embeddings (reason: {0})")]
EmbeddingsFailed(String),
}

impl LlamaSession {
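With the new variant, callers can distinguish an embeddings failure from a generic decode failure. A usage sketch, assuming embeddings_async surfaces LlamaContextError as the inference path above suggests:

match model.embeddings_async(&input, params).await {
    Ok(embeddings) => println!("got {} embeddings", embeddings.len()),
    Err(LlamaContextError::EmbeddingsFailed(reason)) => {
        eprintln!("embeddings failed: {reason}");
    }
    Err(other) => eprintln!("context error: {other}"),
}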
79 changes: 65 additions & 14 deletions crates/llama_cpp_tests/src/lib.rs
@@ -7,26 +7,20 @@
mod tests {
use std::io;
use std::io::Write;
use std::path::Path;
use std::time::Duration;

use futures::StreamExt;
use tokio::select;
use tokio::time::Instant;

use llama_cpp::standard_sampler::StandardSampler;
use llama_cpp::{CompletionHandle, LlamaModel, LlamaParams, SessionParams, TokensToStrings};
use llama_cpp::{
CompletionHandle, EmbeddingsParams, LlamaModel, LlamaParams, SessionParams, TokensToStrings,
};

async fn list_models() -> Vec<String> {
let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
eprintln!(
"LLAMA_CPP_TEST_MODELS environment variable not set. \
Please set this to the directory containing one or more GGUF models."
);

std::process::exit(1)
});

let dir = std::path::Path::new(&dir);
async fn list_models(dir: impl AsRef<Path>) -> Vec<String> {
let dir = dir.as_ref();

if !dir.is_dir() {
panic!("\"{}\" is not a directory", dir.to_string_lossy());
@@ -53,7 +47,14 @@ mod tests {
#[ignore]
#[tokio::test]
async fn load_models() {
let models = list_models().await;
let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
panic!(
"LLAMA_CPP_TEST_MODELS environment variable not set. \
Please set this to the directory containing one or more GGUF models."
);
});

let models = list_models(dir).await;

for model in models {
println!("Loading model: {}", model);
@@ -65,7 +66,14 @@

#[tokio::test]
async fn execute_completions() {
let models = list_models().await;
let dir = std::env::var("LLAMA_CPP_TEST_MODELS").unwrap_or_else(|_| {
panic!(
"LLAMA_CPP_TEST_MODELS environment variable not set. \
Please set this to the directory containing one or more GGUF models."
);
});

let models = list_models(dir).await;

for model in models {
let mut params = LlamaParams::default();
@@ -124,4 +132,47 @@
println!();
}
}

#[tokio::test]
async fn embed() {
let dir = std::env::var("LLAMA_EMBED_MODELS_DIR").unwrap_or_else(|_| {
panic!(
"LLAMA_EMBED_MODELS_DIR environment variable not set. \
Please set this to the directory containing one or more embedding GGUF models."
);
});

let models = list_models(dir).await;

for model in models {
let params = LlamaParams::default();
let model = LlamaModel::load_from_file_async(model, params)
.await
.expect("Failed to load model");

let mut input = vec![];

for _phrase_idx in 0..2 {
let mut phrase = String::new();
for _word_idx in 0..3000 {
phrase.push_str("word ");
}
phrase.truncate(phrase.len() - 1);
input.push(phrase);
}

let params = EmbeddingsParams::default();
let res = model
.embeddings_async(&input, params)
.await
.expect("Failed to infer embeddings");

for embedding in &res {
assert!(embedding[0].is_normal(), "Embedding value isn't normal");
assert!(embedding[0] >= 0f32, "Embedding value isn't normalised");
assert!(embedding[0] <= 1f32, "Embedding value isn't normalised");
}
println!("{:?}", res[0]);
}
}
}
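
The assertions above only spot-check the first component, and since components of a unit vector lie in [-1, 1], the non-negativity check depends on the particular model and input. A stricter check would assert unit length directly; a possible extension, not part of this commit:

for embedding in &res {
    let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-4, "Embedding is not unit-length: {norm}");
}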
