Skip to content

Commit

Permalink
[ENH] Add brute force operator (#1907)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - /
 - New functionality
	 - Adds a brute force operator type
	 
## Test plan
*How are these changes tested?*
- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Mar 21, 2024
1 parent 7417f3d commit f5e173d
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 56 deletions.
3 changes: 3 additions & 0 deletions rust/worker/src/distance/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod types;

pub use types::*;
145 changes: 145 additions & 0 deletions rust/worker/src/distance/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use crate::errors::{ChromaError, ErrorCodes};
use thiserror::Error;

/// Distance functions supported by indices in Chroma.
///
/// # Variants
/// - `Euclidean` - The squared Euclidean (l2) distance; no square root is taken.
/// - `Cosine` - The cosine distance, computed as `1 - dot(a, b)`. The inputs are
///   assumed to be normalized already.
/// - `InnerProduct` - The inner product distance, computed as `1 - dot(a, b)`.
///
/// # Notes
/// See https://docs.trychroma.com/usage-guide#changing-the-distance-function
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum DistanceFunction {
    Euclidean,
    Cosine,
    InnerProduct,
}

impl DistanceFunction {
    /// Computes the distance between `a` and `b` under this metric.
    ///
    /// Only the first `a.len()` components of `b` are read; any extra trailing
    /// components of `b` are ignored.
    ///
    /// # Panics
    /// Panics if `b` has fewer components than `a`.
    // TODO: Should we error if mismatched dimensions?
    pub(crate) fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        // TODO: implement this in SSE/AVX SIMD
        // Slice `b` once up front so the length requirement is checked a single
        // time, then accumulate sequentially, front to back.
        let b = &b[..a.len()];
        let pairs = a.iter().zip(b);
        match self {
            Self::Euclidean => pairs.fold(0.0_f32, |acc, (x, y)| acc + (x - y).powi(2)),
            // For cosine we just assume the vectors have been normalized, since
            // that is what our indices expect; it then reduces to the same
            // computation as the inner product.
            Self::Cosine | Self::InnerProduct => {
                1.0_f32 - pairs.fold(0.0_f32, |acc, (x, y)| acc + x * y)
            }
        }
    }
}

/// Errors produced when constructing a `DistanceFunction`.
#[derive(Error, Debug)]
pub(crate) enum DistanceFunctionError {
/// The given string does not name a supported distance function.
/// The payload is the string that failed to parse.
#[error("Invalid distance function `{0}`")]
InvalidDistanceFunction(String),
}

impl ChromaError for DistanceFunctionError {
    /// Maps each error variant to its error code. An unrecognized distance
    /// function name is reported as an invalid argument.
    fn code(&self) -> ErrorCodes {
        match self {
            Self::InvalidDistanceFunction(_) => ErrorCodes::InvalidArgument,
        }
    }
}

impl TryFrom<&str> for DistanceFunction {
    type Error = DistanceFunctionError;

    /// Parses a distance function from its string name.
    ///
    /// Accepted names are `"l2"`, `"cosine"` and `"ip"`; matching is exact.
    ///
    /// # Errors
    /// Returns `DistanceFunctionError::InvalidDistanceFunction` carrying the
    /// offending string for any other input.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let parsed = match value {
            "l2" => Self::Euclidean,
            "cosine" => Self::Cosine,
            "ip" => Self::InnerProduct,
            unknown => {
                return Err(DistanceFunctionError::InvalidDistanceFunction(
                    unknown.to_string(),
                ));
            }
        };
        Ok(parsed)
    }
}

/// Converts a distance function into its string name: `"l2"`, `"cosine"` or
/// `"ip"`. These are the same names accepted by `TryFrom<&str>`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `Into<String> for DistanceFunction`, so existing
/// `.into()` call sites keep working, and `String::from(..)` becomes available.
impl From<DistanceFunction> for String {
    fn from(value: DistanceFunction) -> Self {
        let name = match value {
            DistanceFunction::Euclidean => "l2",
            DistanceFunction::Cosine => "cosine",
            DistanceFunction::InnerProduct => "ip",
        };
        name.to_string()
    }
}

// Unit tests for `DistanceFunction`: string conversions in both directions and
// the `distance` computation.
#[cfg(test)]
mod tests {
use super::*;
use std::convert::TryInto;

// Every supported name parses to the matching variant.
#[test]
fn test_distance_function_try_from() {
let distance_function: DistanceFunction = "l2".try_into().unwrap();
assert_eq!(distance_function, DistanceFunction::Euclidean);
let distance_function: DistanceFunction = "cosine".try_into().unwrap();
assert_eq!(distance_function, DistanceFunction::Cosine);
let distance_function: DistanceFunction = "ip".try_into().unwrap();
assert_eq!(distance_function, DistanceFunction::InnerProduct);
}

// Every variant converts back to the name it is parsed from.
#[test]
fn test_distance_function_into() {
let distance_function: String = DistanceFunction::Euclidean.into();
assert_eq!(distance_function, "l2");
let distance_function: String = DistanceFunction::Cosine.into();
assert_eq!(distance_function, "cosine");
let distance_function: String = DistanceFunction::InnerProduct.into();
assert_eq!(distance_function, "ip");
}

// Checks `distance` for both the "l2" and "ip" metrics against reference
// values computed inline. The assertions use exact float equality.
#[test]
fn test_distance_function_l2sqr() {
let a = vec![1.0, 2.0, 3.0];
let a_mag = (1.0_f32.powi(2) + 2.0_f32.powi(2) + 3.0_f32.powi(2)).sqrt();
// `a` scaled to unit length.
let a_norm = vec![1.0 / a_mag, 2.0 / a_mag, 3.0 / a_mag];
let b = vec![4.0, 5.0, 6.0];
let b_mag = (4.0_f32.powi(2) + 5.0_f32.powi(2) + 6.0_f32.powi(2)).sqrt();
// `b` scaled to unit length.
let b_norm = vec![4.0 / b_mag, 5.0 / b_mag, 6.0 / b_mag];

// Expected squared l2 distance between the unnormalized vectors.
let l2_sqr = (1.0 - 4.0_f32).powi(2) + (2.0 - 5.0_f32).powi(2) + (3.0 - 6.0_f32).powi(2);
// Despite the name, this holds the expected *distance*:
// 1 - dot(a_norm, b_norm).
let inner_product_sim = 1.0_f32
- a_norm
.iter()
.zip(b_norm.iter())
.map(|(a, b)| a * b)
.sum::<f32>();

let distance_function: DistanceFunction = "l2".try_into().unwrap();
assert_eq!(distance_function.distance(&a, &b), l2_sqr);
let distance_function: DistanceFunction = "ip".try_into().unwrap();
assert_eq!(
distance_function.distance(&a_norm, &b_norm),
inner_product_sim
);
}
}
107 changes: 107 additions & 0 deletions rust/worker/src/execution/operators/brute_force_knn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use crate::{distance::DistanceFunction, execution::operator::Operator};
use async_trait::async_trait;

/// The brute force k-nearest neighbors operator is responsible for computing the k-nearest neighbors
/// of a given query vector against a set of vectors using brute force calculation.
/// The struct has no fields; everything it works on is supplied through `BruteForceKnnOperatorInput`.
/// # Note
/// - Callers should ensure that the input vectors are normalized if using the cosine similarity metric.
#[derive(Debug)]
pub struct BruteForceKnnOperator {}

/// The input to the brute force k-nearest neighbors operator.
/// # Parameters
/// * `data` - The vectors to query against. Each inner `Vec` is one candidate vector.
/// * `query` - The query vector. A single query is handled per input.
/// * `k` - The number of nearest neighbors to find.
/// * `distance_metric` - The distance metric used to compare `query` with each entry of `data`.
pub struct BruteForceKnnOperatorInput {
pub data: Vec<Vec<f32>>,
pub query: Vec<f32>,
pub k: usize,
pub distance_metric: DistanceFunction,
}

/// The output of the brute force k-nearest neighbors operator.
/// # Parameters
/// * `indices` - The indices of the nearest neighbors, as positions into the `data` input,
/// ordered from nearest to farthest.
/// * `distances` - The distances of the nearest neighbors, in the same order as `indices`.
pub struct BruteForceKnnOperatorOutput {
pub indices: Vec<usize>,
pub distances: Vec<f32>,
}

#[async_trait]
impl Operator<BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput> for BruteForceKnnOperator {
    type Error = ();

    /// Computes the distance from `input.query` to every vector in
    /// `input.data` and returns the `input.k` closest ones, nearest first.
    ///
    /// * If `input.k` exceeds the number of data vectors, all of them are
    ///   returned instead of panicking.
    /// * NaN distances are ordered after every other distance instead of
    ///   panicking during the sort.
    /// * The sort is stable, so vectors at equal distance keep their original
    ///   relative order (lower index first).
    async fn run(
        &self,
        input: &BruteForceKnnOperatorInput,
    ) -> Result<BruteForceKnnOperatorOutput, Self::Error> {
        // We could use a heap approach here, but for now we just sort the distances and take the
        // first k.
        let mut indexed_distances = input
            .data
            .iter()
            .map(|candidate| input.distance_metric.distance(&input.query, candidate))
            .enumerate()
            .collect::<Vec<(usize, f32)>>();

        // `partial_cmp` yields `None` only when at least one side is NaN. In
        // that case fall back to comparing the `is_nan` flags, which places NaN
        // after every number and treats two NaNs as equal.
        indexed_distances.sort_by(|(_, lhs), (_, rhs)| {
            lhs.partial_cmp(rhs)
                .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()))
        });

        // Clamp so that asking for more neighbors than there are vectors
        // returns everything available.
        let k = input.k.min(indexed_distances.len());
        let (indices, distances): (Vec<usize>, Vec<f32>) =
            indexed_distances.into_iter().take(k).unzip();

        Ok(BruteForceKnnOperatorOutput { indices, distances })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The two nearest neighbors of the origin under squared l2 distance are
    /// the origin itself and then `[0, 1, 1]`.
    #[tokio::test]
    async fn test_brute_force_knn_l2sqr() {
        let operator = BruteForceKnnOperator {};
        let input = BruteForceKnnOperatorInput {
            data: vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 1.0, 1.0],
                vec![7.0, 8.0, 9.0],
            ],
            query: vec![0.0, 0.0, 0.0],
            k: 2,
            distance_metric: DistanceFunction::Euclidean,
        };
        let output = operator.run(&input).await.unwrap();
        assert_eq!(output.indices, vec![0, 1]);
        let distance_1 = 0.0_f32.powi(2) + 1.0_f32.powi(2) + 1.0_f32.powi(2);
        assert_eq!(output.distances, vec![0.0, distance_1]);
    }

    /// Nearest neighbors over unit-length vectors. The inner product metric is
    /// used here; on normalized inputs `DistanceFunction` computes it the same
    /// way as the cosine metric.
    #[tokio::test]
    async fn test_brute_force_knn_cosine() {
        let operator = BruteForceKnnOperator {};

        let norm_1 = (1.0_f32.powi(2) + 2.0_f32.powi(2) + 3.0_f32.powi(2)).sqrt();
        let data_1 = vec![1.0 / norm_1, 2.0 / norm_1, 3.0 / norm_1];

        // The parentheses matter: `-1.0_f32.powi(2)` parses as `-(1.0.powi(2))`,
        // which would give a norm of sqrt(35) and leave `data_2` unnormalized.
        let norm_2 = (0.0_f32.powi(2) + (-1.0_f32).powi(2) + 6.0_f32.powi(2)).sqrt();
        let data_2 = vec![0.0 / norm_2, -1.0 / norm_2, 6.0 / norm_2];

        let input = BruteForceKnnOperatorInput {
            data: vec![vec![0.0, 1.0, 0.0], data_1.clone(), data_2],
            query: vec![0.0, 1.0, 0.0],
            k: 2,
            distance_metric: DistanceFunction::InnerProduct,
        };
        let output = operator.run(&input).await.unwrap();

        assert_eq!(output.indices, vec![0, 1]);
        let expected_distance_1 =
            1.0f32 - ((data_1[0] * 0.0) + (data_1[1] * 1.0) + (data_1[2] * 0.0));
        assert_eq!(output.distances, vec![0.0, expected_distance_1]);
    }
}
2 changes: 2 additions & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
pub(super) mod brute_force_knn;
pub(super) mod normalize_vectors;
pub(super) mod pull_log;
81 changes: 81 additions & 0 deletions rust/worker/src/execution/operators/normalize_vectors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::execution::operator::Operator;
use async_trait::async_trait;

/// Small constant added to a vector's norm before taking its reciprocal, so
/// that normalizing an all-zero vector does not divide by zero.
const EPS: f32 = 1e-30;

/// Operator that rescales each input vector by the reciprocal of its l2 norm.
/// The struct has no fields; the vectors arrive through `NormalizeVectorOperatorInput`.
#[derive(Debug)]
pub struct NormalizeVectorOperator {}

/// The input to the normalize vector operator.
/// # Parameters
/// * `vectors` - The vectors to normalize. Each inner `Vec` is normalized independently.
pub struct NormalizeVectorOperatorInput {
pub vectors: Vec<Vec<f32>>,
}

/// The output of the normalize vector operator.
/// # Parameters
/// * `normalized_vectors` - The normalized vectors, one per input vector and in the same order.
pub struct NormalizeVectorOperatorOutput {
pub normalized_vectors: Vec<Vec<f32>>,
}

#[async_trait]
impl Operator<NormalizeVectorOperatorInput, NormalizeVectorOperatorOutput>
    for NormalizeVectorOperator
{
    type Error = ();

    /// Returns a copy of every input vector scaled by `1 / (norm + EPS)`,
    /// where `norm` is the vector's l2 norm. The output has one vector per
    /// input vector, in the same order. This implementation always returns `Ok`.
    async fn run(
        &self,
        input: &NormalizeVectorOperatorInput,
    ) -> Result<NormalizeVectorOperatorOutput, Self::Error> {
        // TODO: this should not have to reallocate the vectors. We can optimize this later.
        let normalized_vectors = input
            .vectors
            .iter()
            .map(|vector| {
                // Sum of squares, accumulated front to back.
                let squared_norm = vector.iter().fold(0.0_f32, |acc, x| acc + x * x);
                let scale = 1.0 / (squared_norm.sqrt() + EPS);
                vector.iter().map(|x| x * scale).collect::<Vec<f32>>()
            })
            .collect::<Vec<Vec<f32>>>();
        Ok(NormalizeVectorOperatorOutput { normalized_vectors })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const COMPARE_EPS: f32 = 1e-9;

    /// Returns true when `a` and `b` have the same length and every pair of
    /// components differs by less than `COMPARE_EPS`.
    ///
    /// The explicit length check is needed because `zip` stops at the shorter
    /// slice, which would otherwise let a truncated result compare as equal.
    fn float_eps_eq(a: &[f32], b: &[f32]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b.iter())
                .all(|(a, b)| (a - b).abs() < COMPARE_EPS)
    }

    /// Normalizes three vectors and compares each against precomputed values.
    #[tokio::test]
    async fn test_normalize_vector() {
        let operator = NormalizeVectorOperator {};
        let input = NormalizeVectorOperatorInput {
            vectors: vec![
                vec![1.0, 2.0, 3.0],
                vec![4.0, 5.0, 6.0],
                vec![7.0, 8.0, 9.0],
            ],
        };

        let output = operator.run(&input).await.unwrap();
        let expected_output = NormalizeVectorOperatorOutput {
            normalized_vectors: vec![
                vec![0.26726124, 0.5345225, 0.8017837],
                vec![0.45584232, 0.5698029, 0.6837635],
                vec![0.5025707, 0.5743665, 0.64616233],
            ],
        };

        // Guard against missing or extra vectors, which the `zip` below would
        // silently skip.
        assert_eq!(
            output.normalized_vectors.len(),
            expected_output.normalized_vectors.len()
        );

        for (a, b) in output
            .normalized_vectors
            .iter()
            .zip(expected_output.normalized_vectors.iter())
        {
            assert!(float_eps_eq(a, b), "{:?} != {:?}", a, b);
        }
    }
}
2 changes: 1 addition & 1 deletion rust/worker/src/index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ extern "C" {
pub mod test {
use super::*;

use crate::index::types::DistanceFunction;
use crate::distance::DistanceFunction;
use crate::index::utils;
use rand::Rng;
use rayon::prelude::*;
Expand Down
Loading

0 comments on commit f5e173d

Please sign in to comment.