Skip to content

Commit

Permalink
feat: support float16/float64 for multivector
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal committed Jan 15, 2025
1 parent b572905 commit f4cf1a1
Showing 1 changed file with 59 additions and 18 deletions.
77 changes: 59 additions & 18 deletions rust/lance-linalg/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, UInt8Type};
use arrow_array::{Array, FixedSizeListArray, Float32Array, ListArray};
use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
use arrow_schema::{ArrowError, DataType};

pub mod cosine;
Expand Down Expand Up @@ -117,6 +117,17 @@ pub fn multivec_distance(
));
};

// check the query vectors type first
// because we don't want to check the vectors type for each vector
match query.data_type() {
DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
_ => {
return Err(ArrowError::InvalidArgumentError(
"query must be a float array or binary array".to_string(),
));
}
}

let dists = vectors
.iter()
.map(|v| {
Expand All @@ -139,26 +150,56 @@ pub fn multivec_distance(
})
.sum()
}
_ => {
let query = query.as_primitive::<Float32Type>().values();
query
.chunks_exact(dim)
.map(|q| {
multivector
.values()
.as_primitive::<Float32Type>()
.values()
.chunks_exact(dim)
.map(|v| distance_type.func()(q, v))
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
})
.sum()
}
_ => match query.data_type() {
DataType::Float16 => multivec_distance_impl::<Float16Type>(
query,
multivector,
dim,
distance_type,
),
DataType::Float32 => multivec_distance_impl::<Float32Type>(
query,
multivector,
dim,
distance_type,
),
DataType::Float64 => multivec_distance_impl::<Float64Type>(
query,
multivector,
dim,
distance_type,
),
_ => unreachable!("missed to check query type"),
},
}
})
.unwrap_or(f32::NAN)
})
.collect();
Ok(dists)
}

fn multivec_distance_impl<T: ArrowPrimitiveType>(
query: &dyn Array,
multivector: &FixedSizeListArray,
dim: usize,
distance_type: DistanceType,
) -> f32
where
T::Native: L2 + Cosine + Dot,
{
let query = query.as_primitive::<T>().values();
query
.chunks_exact(dim)
.map(|q| {
multivector
.values()
.as_primitive::<T>()
.values()
.chunks_exact(dim)
.map(|v| distance_type.func()(q, v))
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
})
.sum()
}

0 comments on commit f4cf1a1

Please sign in to comment.