Skip to content

Commit

Permalink
feat: vector search with distance range (#3326)
Browse files Browse the repository at this point in the history
  • Loading branch information
BubbleCal authored Jan 3, 2025
1 parent 8585207 commit 39f12dc
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 58 deletions.
2 changes: 2 additions & 0 deletions java/core/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result<Option<Query>>
column,
key,
k,
lower_bound: None,
upper_bound: None,
nprobes,
ef,
refine_factor,
Expand Down
6 changes: 6 additions & 0 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ pub struct Query {
/// Top k results to return.
pub k: usize,

/// The lower bound (inclusive) of the distance to be searched.
pub lower_bound: Option<f32>,

/// The upper bound (exclusive) of the distance to be searched.
pub upper_bound: Option<f32>,

/// The number of probes to load and search.
pub nprobes: usize,

Expand Down
80 changes: 46 additions & 34 deletions rust/lance-index/src/vector/flat/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use arrow::array::AsArray;
use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use deepsize::DeepSizeOf;
use itertools::Itertools;
use lance_core::{Error, Result, ROW_ID_FIELD};
use lance_file::reader::FileReader;
use lance_linalg::distance::DistanceType;
Expand Down Expand Up @@ -44,11 +43,17 @@ lazy_static::lazy_static! {
}

#[derive(Default)]
pub struct FlatQueryParams {}
pub struct FlatQueryParams {
lower_bound: Option<f32>,
upper_bound: Option<f32>,
}

impl From<&Query> for FlatQueryParams {
fn from(_: &Query) -> Self {
Self {}
fn from(q: &Query) -> Self {
Self {
lower_bound: q.lower_bound,
upper_bound: q.upper_bound,
}
}
}

Expand All @@ -72,50 +77,57 @@ impl IvfSubIndex for FlatIndex {
&self,
query: ArrayRef,
k: usize,
_params: Self::QueryParams,
params: Self::QueryParams,
storage: &impl VectorStore,
prefilter: Arc<dyn PreFilter>,
) -> Result<RecordBatch> {
let dist_calc = storage.dist_calculator(query);

let (row_ids, dists): (Vec<u64>, Vec<f32>) = match prefilter.is_empty() {
true => dist_calc
.distance_all()
.into_iter()
.zip(0..storage.len() as u32)
.map(|(dist, id)| OrderedNode {
id,
dist: OrderedFloat(dist),
})
.sorted_unstable()
.take(k)
.map(
|OrderedNode {
id,
dist: OrderedFloat(dist),
}| (storage.row_id(id), dist),
)
.unzip(),
let mut res: Vec<_> = match prefilter.is_empty() {
true => {
let iter = dist_calc
.distance_all()
.into_iter()
.zip(0..storage.len() as u32)
.map(|(dist, id)| OrderedNode {
id,
dist: OrderedFloat(dist),
});

if params.lower_bound.is_some() || params.upper_bound.is_some() {
let lower_bound = params.lower_bound.unwrap_or(f32::MIN);
let upper_bound = params.upper_bound.unwrap_or(f32::MAX);
iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound)
.collect()
} else {
iter.collect()
}
}
false => {
let row_id_mask = prefilter.mask();
(0..storage.len())
let iter = (0..storage.len())
.filter(|&id| row_id_mask.selected(storage.row_id(id as u32)))
.map(|id| OrderedNode {
id: id as u32,
dist: OrderedFloat(dist_calc.distance(id as u32)),
})
.sorted_unstable()
.take(k)
.map(
|OrderedNode {
id,
dist: OrderedFloat(dist),
}| (storage.row_id(id), dist),
)
.unzip()
});
if params.lower_bound.is_some() || params.upper_bound.is_some() {
let lower_bound = params.lower_bound.unwrap_or(f32::MIN);
let upper_bound = params.upper_bound.unwrap_or(f32::MAX);
iter.filter(|r| lower_bound <= r.dist.0 && r.dist.0 < upper_bound)
.collect()
} else {
iter.collect()
}
}
};
res.sort_unstable();

let (row_ids, dists): (Vec<_>, Vec<_>) = res
.into_iter()
.take(k)
.map(|r| (storage.row_id(r.id), r.dist.0))
.unzip();
let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists));

Ok(RecordBatch::try_new(
Expand Down
63 changes: 61 additions & 2 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion::physical_plan::{
ExecutionPlan, SendableRecordBatchStream,
};
use datafusion::scalar::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
use futures::future::BoxFuture;
Expand Down Expand Up @@ -705,6 +706,8 @@ impl Scanner {
column: column.to_string(),
key: key.into(),
k,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand All @@ -714,6 +717,19 @@ impl Scanner {
Ok(self)
}

/// Set the distance thresholds for the nearest neighbor search.
pub fn distance_range(
&mut self,
lower_bound: Option<f32>,
upper_bound: Option<f32>,
) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.lower_bound = lower_bound;
q.upper_bound = upper_bound;
}
self
}

pub fn nprobs(&mut self, n: usize) -> &mut Self {
if let Some(q) = self.nearest.as_mut() {
q.nprobes = n;
Expand Down Expand Up @@ -1994,16 +2010,59 @@ impl Scanner {
q.metric_type,
)?);

// filter out elements out of distance range
let lower_bound_expr = q
.lower_bound
.map(|v| {
let lower_bound = expressions::lit(v);
expressions::binary(
expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
Operator::GtEq,
lower_bound,
flat_dist.schema().as_ref(),
)
})
.transpose()?;
let upper_bound_expr = q
.upper_bound
.map(|v| {
let upper_bound = expressions::lit(v);
expressions::binary(
expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
Operator::Lt,
upper_bound,
flat_dist.schema().as_ref(),
)
})
.transpose()?;
let filter_expr = match (lower_bound_expr, upper_bound_expr) {
(Some(lower), Some(upper)) => Some(expressions::binary(
lower,
Operator::And,
upper,
flat_dist.schema().as_ref(),
)?),
(Some(lower), None) => Some(lower),
(None, Some(upper)) => Some(upper),
(None, None) => None,
};

let knn_plan: Arc<dyn ExecutionPlan> = if let Some(filter_expr) = filter_expr {
Arc::new(FilterExec::try_new(filter_expr, flat_dist)?)
} else {
flat_dist
};

// Use DataFusion's [SortExec] for Top-K search
let sort = SortExec::new(
vec![PhysicalSortExpr {
expr: expressions::col(DIST_COL, flat_dist.schema().as_ref())?,
expr: expressions::col(DIST_COL, knn_plan.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
}],
flat_dist,
knn_plan,
)
.with_fetch(Some(q.k));

Expand Down
2 changes: 2 additions & 0 deletions rust/lance/src/index/vector/fixture_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ mod test {
column: "test".to_string(),
key: Arc::new(Float32Array::from(query)),
k: 1,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand Down
2 changes: 2 additions & 0 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1978,6 +1978,8 @@ mod tests {
column: Self::COLUMN.to_string(),
key: Arc::new(row),
k: 5,
lower_bound: None,
upper_bound: None,
nprobes: 1,
ef: None,
refine_factor: None,
Expand Down
Loading

0 comments on commit 39f12dc

Please sign in to comment.