-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH] Add brute force operator (#1907)
## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - / - New functionality - Adds a brute force operator type ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None
- Loading branch information
Showing
8 changed files
with
341 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mod types; | ||
|
||
pub use types::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
use crate::errors::{ChromaError, ErrorCodes}; | ||
use thiserror::Error; | ||
|
||
/// The distance function enum.
/// # Description
/// This enum defines the distance functions supported by indices in Chroma.
/// # Variants
/// - `Euclidean` - The squared Euclidean (squared l2) distance. No square root is taken.
/// - `Cosine` - The cosine distance. Specifically, 1 - cosine. The vectors are assumed to
///   have been normalized already.
/// - `InnerProduct` - The inner product. Specifically, 1 - inner product.
/// # Notes
/// See https://docs.trychroma.com/usage-guide#changing-the-distance-function
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum DistanceFunction {
    Euclidean,
    Cosine,
    InnerProduct,
}
|
||
impl DistanceFunction { | ||
// TOOD: Should we error if mismatched dimensions? | ||
pub(crate) fn distance(&self, a: &[f32], b: &[f32]) -> f32 { | ||
// TODO: implement this in SSE/AVX SIMD | ||
// For now we write these as loops since we suspect that will more likely | ||
// lead to the compiler vectorizing the code. (We saw this on | ||
// Apple Silicon Macs who didn't have hand-rolled SIMD instructions in our | ||
// C++ code). | ||
match self { | ||
DistanceFunction::Euclidean => { | ||
let mut sum = 0.0; | ||
for i in 0..a.len() { | ||
sum += (a[i] - b[i]).powi(2); | ||
} | ||
sum | ||
} | ||
DistanceFunction::Cosine => { | ||
// For cosine we just assume the vectors have been normalized, since that | ||
// is what our indices expect. | ||
let mut sum = 0.0; | ||
for i in 0..a.len() { | ||
sum += a[i] * b[i]; | ||
} | ||
1.0_f32 - sum | ||
} | ||
DistanceFunction::InnerProduct => { | ||
let mut sum = 0.0; | ||
for i in 0..a.len() { | ||
sum += a[i] * b[i]; | ||
} | ||
1.0_f32 - sum | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[derive(Error, Debug)] | ||
pub(crate) enum DistanceFunctionError { | ||
#[error("Invalid distance function `{0}`")] | ||
InvalidDistanceFunction(String), | ||
} | ||
|
||
impl ChromaError for DistanceFunctionError { | ||
fn code(&self) -> ErrorCodes { | ||
match self { | ||
DistanceFunctionError::InvalidDistanceFunction(_) => ErrorCodes::InvalidArgument, | ||
} | ||
} | ||
} | ||
|
||
impl TryFrom<&str> for DistanceFunction { | ||
type Error = DistanceFunctionError; | ||
|
||
fn try_from(value: &str) -> Result<Self, Self::Error> { | ||
match value { | ||
"l2" => Ok(DistanceFunction::Euclidean), | ||
"cosine" => Ok(DistanceFunction::Cosine), | ||
"ip" => Ok(DistanceFunction::InnerProduct), | ||
_ => Err(DistanceFunctionError::InvalidDistanceFunction( | ||
value.to_string(), | ||
)), | ||
} | ||
} | ||
} | ||
|
||
impl Into<String> for DistanceFunction { | ||
fn into(self) -> String { | ||
match self { | ||
DistanceFunction::Euclidean => "l2".to_string(), | ||
DistanceFunction::Cosine => "cosine".to_string(), | ||
DistanceFunction::InnerProduct => "ip".to_string(), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::TryInto;

    /// Each recognized name parses to its matching variant.
    #[test]
    fn test_distance_function_try_from() {
        let cases = [
            ("l2", DistanceFunction::Euclidean),
            ("cosine", DistanceFunction::Cosine),
            ("ip", DistanceFunction::InnerProduct),
        ];
        for (name, expected) in cases.iter() {
            let parsed: DistanceFunction = (*name).try_into().unwrap();
            assert_eq!(&parsed, expected);
        }
    }

    /// Each variant converts to its string name.
    #[test]
    fn test_distance_function_into() {
        let cases = [
            (DistanceFunction::Euclidean, "l2"),
            (DistanceFunction::Cosine, "cosine"),
            (DistanceFunction::InnerProduct, "ip"),
        ];
        for (variant, expected) in cases.iter() {
            let name: String = variant.clone().into();
            assert_eq!(name, *expected);
        }
    }

    /// Checks the squared-l2 distance on raw vectors and the inner product distance on
    /// their unit-length counterparts.
    #[test]
    fn test_distance_function_l2sqr() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];

        // Unit-length versions of the two vectors, used for the inner product check.
        let lhs_len = (1.0_f32.powi(2) + 2.0_f32.powi(2) + 3.0_f32.powi(2)).sqrt();
        let rhs_len = (4.0_f32.powi(2) + 5.0_f32.powi(2) + 6.0_f32.powi(2)).sqrt();
        let lhs_unit = vec![1.0 / lhs_len, 2.0 / lhs_len, 3.0 / lhs_len];
        let rhs_unit = vec![4.0 / rhs_len, 5.0 / rhs_len, 6.0 / rhs_len];

        let expected_l2_sqr =
            (1.0 - 4.0_f32).powi(2) + (2.0 - 5.0_f32).powi(2) + (3.0 - 6.0_f32).powi(2);
        let dot = lhs_unit
            .iter()
            .zip(rhs_unit.iter())
            .map(|(x, y)| x * y)
            .sum::<f32>();
        let expected_inner_product = 1.0_f32 - dot;

        let euclidean: DistanceFunction = "l2".try_into().unwrap();
        assert_eq!(euclidean.distance(&lhs, &rhs), expected_l2_sqr);

        let inner_product: DistanceFunction = "ip".try_into().unwrap();
        assert_eq!(
            inner_product.distance(&lhs_unit, &rhs_unit),
            expected_inner_product
        );
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
use crate::{distance::DistanceFunction, execution::operator::Operator}; | ||
use async_trait::async_trait; | ||
|
||
/// Operator that finds the k-nearest neighbors of a query vector by computing its distance
/// to every vector in a data set (brute force).
/// # Note
/// - When the cosine similarity metric is used, callers are responsible for normalizing the
///   input vectors beforehand.
#[derive(Debug)]
pub struct BruteForceKnnOperator {}
|
||
/// The input to the brute force k-nearest neighbors operator. | ||
/// # Parameters | ||
/// * `data` - The vectors to query against. | ||
/// * `query` - The query vector. | ||
/// * `k` - The number of nearest neighbors to find. | ||
/// * `distance_metric` - The distance metric to use. | ||
pub struct BruteForceKnnOperatorInput { | ||
pub data: Vec<Vec<f32>>, | ||
pub query: Vec<f32>, | ||
pub k: usize, | ||
pub distance_metric: DistanceFunction, | ||
} | ||
|
||
/// The output of the brute force k-nearest neighbors operator.
/// # Parameters
/// * `indices` - The indices of the nearest neighbors, ordered from nearest to farthest.
///   Each index is a position in the `data` input.
/// * `distances` - The distances from the query vector to the nearest neighbors.
///   `distances[i]` is the distance for the vector at `indices[i]`.
#[derive(Debug)]
pub struct BruteForceKnnOperatorOutput {
    pub indices: Vec<usize>,
    pub distances: Vec<f32>,
}
|
||
#[async_trait] | ||
impl Operator<BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput> for BruteForceKnnOperator { | ||
type Error = (); | ||
|
||
async fn run( | ||
&self, | ||
input: &BruteForceKnnOperatorInput, | ||
) -> Result<BruteForceKnnOperatorOutput, Self::Error> { | ||
// We could use a heap approach here, but for now we just sort the distances and take the | ||
// first k. | ||
let mut sorted_indices_distances = input | ||
.data | ||
.iter() | ||
.map(|data| input.distance_metric.distance(&input.query, data)) | ||
.enumerate() | ||
.collect::<Vec<(usize, f32)>>(); | ||
sorted_indices_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); | ||
let (sorted_indices, sorted_distances) = sorted_indices_distances.drain(..input.k).unzip(); | ||
|
||
Ok(BruteForceKnnOperatorOutput { | ||
indices: sorted_indices, | ||
distances: sorted_distances, | ||
}) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// The two nearest neighbors under squared-l2 distance are returned nearest first.
    #[tokio::test]
    async fn test_brute_force_knn_l2sqr() {
        let operator = BruteForceKnnOperator {};
        let input = BruteForceKnnOperatorInput {
            data: vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 1.0, 1.0],
                vec![7.0, 8.0, 9.0],
            ],
            query: vec![0.0, 0.0, 0.0],
            k: 2,
            distance_metric: DistanceFunction::Euclidean,
        };
        let output = operator.run(&input).await.unwrap();
        assert_eq!(output.indices, vec![0, 1]);
        let distance_1 = 0.0_f32.powi(2) + 1.0_f32.powi(2) + 1.0_f32.powi(2);
        assert_eq!(output.distances, vec![0.0, distance_1]);
    }

    /// Uses unit-length vectors, for which `1 - inner product` is the cosine distance.
    #[tokio::test]
    async fn test_brute_force_knn_cosine() {
        let operator = BruteForceKnnOperator {};

        let norm_1 = (1.0_f32.powi(2) + 2.0_f32.powi(2) + 3.0_f32.powi(2)).sqrt();
        let data_1 = vec![1.0 / norm_1, 2.0 / norm_1, 3.0 / norm_1];

        // The parentheses matter: `-1.0_f32.powi(2)` parses as `-(1.0_f32.powi(2))`, which
        // would subtract 1 from the squared norm instead of adding it.
        let norm_2 = (0.0_f32.powi(2) + (-1.0_f32).powi(2) + 6.0_f32.powi(2)).sqrt();
        let data_2 = vec![0.0 / norm_2, -1.0 / norm_2, 6.0 / norm_2];

        let input = BruteForceKnnOperatorInput {
            data: vec![vec![0.0, 1.0, 0.0], data_1.clone(), data_2.clone()],
            query: vec![0.0, 1.0, 0.0],
            k: 2,
            distance_metric: DistanceFunction::InnerProduct,
        };
        let output = operator.run(&input).await.unwrap();

        assert_eq!(output.indices, vec![0, 1]);
        let expected_distance_1 =
            1.0f32 - ((data_1[0] * 0.0) + (data_1[1] * 1.0) + (data_1[2] * 0.0));
        assert_eq!(output.distances, vec![0.0, expected_distance_1]);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
pub(super) mod brute_force_knn; | ||
pub(super) mod normalize_vectors; | ||
pub(super) mod pull_log; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
use crate::execution::operator::Operator; | ||
use async_trait::async_trait; | ||
|
||
/// Small positive offset added to a vector's norm before taking its reciprocal, so that
/// normalizing an all-zero vector divides by `EPS` rather than by zero.
const EPS: f32 = 1.0e-30;
|
||
/// Operator that rescales each input vector by the reciprocal of its l2 norm.
#[derive(Debug)]
pub struct NormalizeVectorOperator {}
|
||
/// The input to the normalize vector operator.
/// # Parameters
/// * `vectors` - The vectors to normalize. Each vector is normalized independently.
#[derive(Debug)]
pub struct NormalizeVectorOperatorInput {
    pub vectors: Vec<Vec<f32>>,
}
|
||
/// The output of the normalize vector operator.
/// # Parameters
/// * `normalized_vectors` - The normalized vectors, in the same order as the input vectors.
#[derive(Debug)]
pub struct NormalizeVectorOperatorOutput {
    pub normalized_vectors: Vec<Vec<f32>>,
}
|
||
#[async_trait] | ||
impl Operator<NormalizeVectorOperatorInput, NormalizeVectorOperatorOutput> | ||
for NormalizeVectorOperator | ||
{ | ||
type Error = (); | ||
|
||
async fn run( | ||
&self, | ||
input: &NormalizeVectorOperatorInput, | ||
) -> Result<NormalizeVectorOperatorOutput, Self::Error> { | ||
// TODO: this should not have to reallocate the vectors. We can optimize this later. | ||
let mut normalized_vectors = Vec::with_capacity(input.vectors.len()); | ||
for vector in &input.vectors { | ||
let mut norm = 0.0; | ||
for x in vector { | ||
norm += x * x; | ||
} | ||
let norm = 1.0 / (norm.sqrt() + EPS); | ||
let normalized_vector = vector.iter().map(|x| x * norm).collect(); | ||
normalized_vectors.push(normalized_vector); | ||
} | ||
Ok(NormalizeVectorOperatorOutput { normalized_vectors }) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    // The compared values lie in (0.25, 1), where adjacent f32 values are roughly 3e-8 to
    // 6e-8 apart. A tolerance below that spacing would only ever accept bit-identical
    // results, so this allows a few units of rounding error.
    const COMPARE_EPS: f32 = 1e-6;

    /// Returns true when both slices have the same length and every pair of corresponding
    /// components differs by less than `COMPARE_EPS`.
    fn float_eps_eq(a: &[f32], b: &[f32]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b.iter())
                .all(|(a, b)| (a - b).abs() < COMPARE_EPS)
    }

    #[tokio::test]
    async fn test_normalize_vector() {
        let operator = NormalizeVectorOperator {};
        let input = NormalizeVectorOperatorInput {
            vectors: vec![
                vec![1.0, 2.0, 3.0],
                vec![4.0, 5.0, 6.0],
                vec![7.0, 8.0, 9.0],
            ],
        };

        let output = operator.run(&input).await.unwrap();
        let expected_output = NormalizeVectorOperatorOutput {
            normalized_vectors: vec![
                vec![0.26726124, 0.5345225, 0.8017837],
                vec![0.45584232, 0.5698029, 0.6837635],
                vec![0.5025707, 0.5743665, 0.64616233],
            ],
        };

        // `zip` stops at the shorter side, so check the vector count explicitly first.
        assert_eq!(
            output.normalized_vectors.len(),
            expected_output.normalized_vectors.len()
        );
        for (a, b) in output
            .normalized_vectors
            .iter()
            .zip(expected_output.normalized_vectors.iter())
        {
            assert!(float_eps_eq(a, b), "{:?} != {:?}", a, b);
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.