From f5980e956e3d417647d6fc19f237490f88fa02d6 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 10 Apr 2024 12:34:28 +0800 Subject: [PATCH] perf: impl heuristic pruning for HNSW (#2171) improve the recall from 90% to 99%, according to the HNSW paper, this also improve performance, especially for highly clustered dataset --------- Signed-off-by: BubbleCal --- rust/lance-index/benches/hnsw.rs | 2 +- rust/lance-index/src/vector/graph.rs | 36 +++--- rust/lance-index/src/vector/graph/builder.rs | 34 +++--- rust/lance-index/src/vector/graph/memory.rs | 4 + rust/lance-index/src/vector/graph/storage.rs | 4 + rust/lance-index/src/vector/hnsw.rs | 112 ++++++++++++------- rust/lance-index/src/vector/hnsw/builder.rs | 15 +-- rust/lance-index/src/vector/pq/storage.rs | 4 + rust/lance-index/src/vector/sq/storage.rs | 29 ++--- rust/lance/examples/hnsw.rs | 2 +- rust/lance/src/index/vector/hnsw.rs | 6 +- 11 files changed, 154 insertions(+), 94 deletions(-) diff --git a/rust/lance-index/benches/hnsw.rs b/rust/lance-index/benches/hnsw.rs index b69f52cd9f..8f6b301f9c 100644 --- a/rust/lance-index/benches/hnsw.rs +++ b/rust/lance-index/benches/hnsw.rs @@ -55,7 +55,7 @@ fn bench_hnsw(c: &mut Criterion) { .search(query, K, 300, None) .unwrap() .iter() - .map(|(i, _)| *i) + .map(|node| node.id) .collect(); assert_eq!(uids.len(), K); diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index fee2b4891b..85f6354e26 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -63,7 +63,7 @@ impl From for GraphNode { /// A wrapper for f32 to make it ordered, so that we can put it into /// a BTree or Heap #[derive(Debug, PartialEq, Clone, Copy)] -pub(crate) struct OrderedFloat(pub f32); +pub struct OrderedFloat(pub f32); impl PartialOrd for OrderedFloat { fn partial_cmp(&self, other: &Self) -> Option { @@ -92,11 +92,17 @@ impl From for f32 { } #[derive(Debug, Eq, PartialEq, Clone)] -pub(crate) struct OrderedNode { +pub struct OrderedNode { pub id: u32, pub dist: OrderedFloat, } +impl OrderedNode { + pub fn new(id: u32, dist: OrderedFloat) -> Self { + Self { id, dist } + } +} + impl PartialOrd for OrderedNode { fn partial_cmp(&self, other: &Self) -> Option { Some(self.dist.cmp(&other.dist)) @@ -181,37 +187,37 @@ pub(super) fn beam_search( k: usize, dist_calc: Option>, bitset: Option<&roaring::bitmap::RoaringBitmap>, -) -> Result> { +) -> Result> { let mut visited: HashSet<_> = start.iter().copied().collect(); let dist_calc = dist_calc.unwrap_or_else(|| graph.storage().dist_calculator(query).into()); - let mut candidates: BinaryHeap> = dist_calc + let mut candidates: BinaryHeap> = dist_calc .distance(start) .iter() .zip(start) - .map(|(&dist, id)| Reverse((dist.into(), *id))) + .map(|(&dist, id)| Reverse((dist.into(), *id).into())) .collect(); - let mut results: BinaryHeap<(OrderedFloat, u32)> = candidates + let mut results: BinaryHeap = candidates .clone() .into_iter() .filter(|node| { bitset - .map(|bitset| bitset.contains(node.0 .1)) + .map(|bitset| bitset.contains(node.0.id)) .unwrap_or(true) }) .map(|v| v.0) .collect(); while !candidates.is_empty() { - let (dist, current) = candidates.pop().expect("candidates is empty").0; + let current = candidates.pop().expect("candidates is empty").0; let furthest = results .peek() - .map(|kv| kv.0) + .map(|node| node.dist) .unwrap_or(OrderedFloat(f32::INFINITY)); - if dist > furthest { + if current.dist > furthest { break; } - let neighbors = graph.neighbors(current).ok_or_else(|| Error::Index { - message: format!("Node {} does not exist in the graph", current), + let neighbors = graph.neighbors(current.id).ok_or_else(|| Error::Index { + message: format!("Node {} does not exist in the graph", current.id), location: location!(), })?; @@ -222,7 +228,7 @@ pub(super) fn beam_search( visited.insert(neighbor); let furthest = results .peek() - .map(|kv| kv.0) + .map(|node| node.dist) .unwrap_or(OrderedFloat(f32::INFINITY)); let dist = dist_calc.distance(&[neighbor])[0].into(); if dist <= furthest || results.len() < k { @@ -230,12 +236,12 @@ pub(super) fn beam_search( .map(|bitset| bitset.contains(neighbor)) .unwrap_or(true) { - results.push((dist, neighbor)); + results.push((dist, neighbor).into()); if results.len() > k { results.pop(); } } - candidates.push(Reverse((dist, neighbor))); + candidates.push(Reverse((dist, neighbor).into())); } } } diff --git a/rust/lance-index/src/vector/graph/builder.rs b/rust/lance-index/src/vector/graph/builder.rs index 941f3b1a2c..a83c070322 100644 --- a/rust/lance-index/src/vector/graph/builder.rs +++ b/rust/lance-index/src/vector/graph/builder.rs @@ -21,6 +21,7 @@ use snafu::{location, Location}; use super::OrderedNode; use super::{memory::InMemoryVectorStorage, Graph, GraphNode, OrderedFloat}; use crate::vector::graph::storage::VectorStorage; +use crate::vector::hnsw::select_neighbors_heuristic; /// GraphNode during build. #[derive(Debug, Clone)] @@ -43,15 +44,6 @@ impl GraphBuilderNode { fn add_neighbor(&mut self, distance: OrderedFloat, id: u32) { self.neighbors.push(OrderedNode { dist: distance, id }); } - - /// Prune the node and only keep `max_edges` edges. - /// - /// Returns the ids of pruned neighbors. - fn prune(&mut self, max_edges: usize) { - while self.neighbors.len() > max_edges { - self.neighbors.pop(); - } - } } impl From<&GraphBuilderNode> for GraphNode { @@ -99,7 +91,7 @@ impl Graph for GraphBuilder { } fn storage(&self) -> Arc { - self.vectors.clone() as Arc + self.vectors.clone() } } @@ -140,11 +132,23 @@ impl GraphBuilder { } pub fn prune(&mut self, node: u32, max_edges: usize) -> Result<()> { - let node = self.nodes.get_mut(&node).ok_or_else(|| Error::Index { - message: format!("Node {} not found", node), - location: location!(), - })?; - node.prune(max_edges); + let vector = self.vectors.vector(node); + + let neighbors = &self + .nodes + .get(&node) + .ok_or_else(|| Error::Index { + message: format!("Node {} not found", node), + location: location!(), + })? + .neighbors; + + let pruned_neighbors = + select_neighbors_heuristic(self, vector, neighbors, max_edges, false).collect(); + + self.nodes + .entry(node) + .and_modify(|node| node.neighbors = pruned_neighbors); Ok(()) } diff --git a/rust/lance-index/src/vector/graph/memory.rs b/rust/lance-index/src/vector/graph/memory.rs index 49838f0611..9211f6c8dd 100644 --- a/rust/lance-index/src/vector/graph/memory.rs +++ b/rust/lance-index/src/vector/graph/memory.rs @@ -50,6 +50,10 @@ impl InMemoryVectorStorage { } impl VectorStorage for InMemoryVectorStorage { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn len(&self) -> usize { self.vectors.num_rows() } diff --git a/rust/lance-index/src/vector/graph/storage.rs b/rust/lance-index/src/vector/graph/storage.rs index a5d82114e1..8f75c7253d 100644 --- a/rust/lance-index/src/vector/graph/storage.rs +++ b/rust/lance-index/src/vector/graph/storage.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; + use lance_linalg::distance::MetricType; pub trait DistCalculator { @@ -26,6 +28,8 @@ pub trait DistCalculator { /// /// TODO: should we rename this to "VectorDistance"?; pub trait VectorStorage: Send + Sync { + fn as_any(&self) -> &dyn Any; + fn len(&self) -> usize; /// Returns true if this graph is empty. diff --git a/rust/lance-index/src/vector/hnsw.rs b/rust/lance-index/src/vector/hnsw.rs index a53d1306a1..13f2629014 100644 --- a/rust/lance-index/src/vector/hnsw.rs +++ b/rust/lance-index/src/vector/hnsw.rs @@ -17,6 +17,7 @@ //! Hierarchical Navigable Small World (HNSW). //! +use std::cmp::Reverse; use std::collections::{BinaryHeap, HashMap, HashSet}; use std::fmt::Debug; use std::ops::Range; @@ -44,6 +45,8 @@ use snafu::{location, Location}; use self::builder::HNSW_METADATA_KEY; +use super::graph::memory::InMemoryVectorStorage; +use super::graph::OrderedNode; use super::graph::{ builder::GraphBuilder, greedy_search, @@ -352,7 +355,7 @@ impl HNSW { k: usize, ef: usize, bitset: Option, - ) -> Result> { + ) -> Result> { let mut ep = self.entry_point; let num_layers = self.levels.len(); @@ -371,9 +374,7 @@ impl HNSW { Some(dist_calc), bitset.as_ref(), )?; - Ok(select_neighbors(&candidates, k) - .map(|(d, u)| (u, d.into())) - .collect()) + Ok(select_neighbors(&candidates, k).cloned().collect()) } /// Returns the metadata of this [`HNSW`]. @@ -477,40 +478,38 @@ impl HNSW { /// /// Algorithm 3 in the HNSW paper. fn select_neighbors( - orderd_candidates: &BinaryHeap<(OrderedFloat, u32)>, + orderd_candidates: &BinaryHeap, k: usize, -) -> impl Iterator + '_ { - orderd_candidates - .iter() - .sorted() - .take(k) - .map(|(d, u)| (*d, *u)) +) -> impl Iterator + '_ { + orderd_candidates.iter().sorted().take(k) } /// Algorithm 4 in the HNSW paper. /// -/// -/// Modifies to the original algorithm: -/// 1. Do not use keepPrunedConnections, we use a heap to capture nearest neighbors. -fn select_neighbors_heuristic( - graph: &dyn Graph, +/// NOTE: the result is not ordered +pub(crate) fn select_neighbors_heuristic( + graph: &GraphBuilder, query: &[f32], - orderd_candidates: &BinaryHeap<(OrderedFloat, u32)>, + orderd_candidates: &BinaryHeap, k: usize, extend_candidates: bool, -) -> impl Iterator { - let mut heap = orderd_candidates.clone(); +) -> impl Iterator { + let mut w = orderd_candidates + .iter() + .cloned() + .map(Reverse) + .collect::>(); if extend_candidates { let dist_calc = graph.storage().dist_calculator(query); - let mut visited = HashSet::with_capacity(orderd_candidates.len() * 64); - visited.extend(orderd_candidates.iter().map(|(_, u)| *u)); - orderd_candidates.iter().sorted().rev().for_each(|(_, u)| { - if let Some(neighbors) = graph.neighbors(*u) { + let mut visited = HashSet::with_capacity(orderd_candidates.len() * k); + visited.extend(orderd_candidates.iter().map(|node| node.id)); + orderd_candidates.iter().sorted().rev().for_each(|node| { + if let Some(neighbors) = graph.neighbors(node.id) { neighbors.for_each(|n| { if !visited.contains(&n) { let d: OrderedFloat = dist_calc.distance(&[n])[0].into(); - heap.push((d, n)); + w.push(Reverse((d, n).into())); } visited.insert(n); }); @@ -518,14 +517,38 @@ fn select_neighbors_heuristic( }); } - heap.into_sorted_vec().into_iter().take(k) + let mut results: Vec = Vec::with_capacity(k); + let mut discarded = Vec::new(); + let storage = graph.storage(); + let storage = storage + .as_any() + .downcast_ref::() + .unwrap(); + while !w.is_empty() && results.len() < k { + let u = w.pop().unwrap().0; + + if results.is_empty() + || results + .iter() + .all(|v| u.dist < OrderedFloat(storage.distance_between(u.id, v.id))) + { + results.push(u); + } else { + discarded.push(u); + } + } + + while results.len() < k && !discarded.is_empty() { + results.push(discarded.pop().unwrap()); + } + + results.into_iter() } #[cfg(test)] mod tests { use super::*; - use crate::vector::graph::memory::InMemoryVectorStorage; use arrow_array::types::Float32Type; use lance_linalg::matrix::MatrixView; use lance_testing::datagen::generate_random_array; @@ -533,29 +556,38 @@ mod tests { #[test] fn test_select_neighbors() { - let candidates: BinaryHeap<(OrderedFloat, u32)> = - (1..6).map(|i| (OrderedFloat(i as f32), i)).collect(); + let candidates: BinaryHeap = + (1..6).map(|i| (OrderedFloat(i as f32), i).into()).collect(); - let result = select_neighbors(&candidates, 3).collect::>(); + let result = select_neighbors(&candidates, 3) + .cloned() + .collect::>(); assert_eq!( result, vec![ - (OrderedFloat(1.0), 1), - (OrderedFloat(2.0), 2), - (OrderedFloat(3.0), 3), + OrderedNode::new(1, OrderedFloat(1.0)), + OrderedNode::new(2, OrderedFloat(2.0)), + OrderedNode::new(3, OrderedFloat(3.0)), ] ); - assert_eq!(select_neighbors(&candidates, 0).collect::>(), vec![]); + assert_eq!( + select_neighbors(&candidates, 0) + .cloned() + .collect::>(), + vec![] + ); assert_eq!( - select_neighbors(&candidates, 8).collect::>(), + select_neighbors(&candidates, 8) + .cloned() + .collect::>(), vec![ - (OrderedFloat(1.0), 1), - (OrderedFloat(2.0), 2), - (OrderedFloat(3.0), 3), - (OrderedFloat(4.0), 4), - (OrderedFloat(5.0), 5), + OrderedNode::new(1, OrderedFloat(1.0)), + OrderedNode::new(2, OrderedFloat(2.0)), + OrderedNode::new(3, OrderedFloat(3.0)), + OrderedNode::new(4, OrderedFloat(4.0)), + OrderedNode::new(5, OrderedFloat(5.0)), ] ); } @@ -632,7 +664,7 @@ mod tests { .search(q, K, 128, None) .unwrap() .iter() - .map(|(i, _)| *i) + .map(|node| node.id) .collect(); let gt = ground_truth(&mat, q, K); let recall = results.intersection(>).count() as f32 / K as f32; diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index 96d511ab63..bf1d927991 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -220,17 +220,18 @@ impl HNSWBuilder { ) .collect() } else { - select_neighbors(&candidates, self.params.m).collect() + select_neighbors(&candidates, self.params.m) + .cloned() + .collect() }; - for (distance, nb) in neighbors.iter() { - cur_level.connect(node, *nb, Some(*distance))?; + for neighbor in neighbors.iter() { + cur_level.connect(node, neighbor.id, Some(neighbor.dist))?; } - for (_, nb) in neighbors { - cur_level.prune(nb, self.params.m_max)?; + for neighbor in neighbors { + cur_level.prune(neighbor.id, self.params.m_max)?; } - cur_level.prune(node, self.params.m_max)?; - ep = candidates.iter().map(|(_, node)| *node).collect(); + ep = candidates.iter().map(|node| node.id).collect(); } if level > self.levels.len() as u16 { diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 51a8c19383..ee38d54bb3 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -405,6 +405,10 @@ impl QuantizerStorage for ProductQuantizationStorage { } impl VectorStorage for ProductQuantizationStorage { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn len(&self) -> usize { self.batch.num_rows() } diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index 58cd7e8f2a..c28007e454 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -202,6 +202,10 @@ impl QuantizerStorage for ScalarQuantizationStorage { } impl VectorStorage for ScalarQuantizationStorage { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn len(&self) -> usize { self.batch.num_rows() } @@ -247,27 +251,26 @@ impl SQDistCalculator { sq_codes, } } - - fn get_sq_code(&self, id: u32) -> &[u8] { - let dim = self.sq_codes.value_length() as usize; - let values: &[u8] = self - .sq_codes - .values() - .as_any() - .downcast_ref::() - .unwrap() - .values(); - &values[id as usize * dim..(id as usize + 1) * dim] - } } impl DistCalculator for SQDistCalculator { fn distance(&self, ids: &[u32]) -> Vec { ids.iter() .map(|&id| { - let sq_code = self.get_sq_code(id); + let sq_code = get_sq_code(&self.sq_codes, id); l2_distance_uint_scalar(sq_code, &self.query_sq_code) }) .collect() } } + +fn get_sq_code(sq_codes: &FixedSizeListArray, id: u32) -> &[u8] { + let dim = sq_codes.value_length() as usize; + let values: &[u8] = sq_codes + .values() + .as_any() + .downcast_ref::() + .unwrap() + .values(); + &values[id as usize * dim..(id as usize + 1) * dim] +} diff --git a/rust/lance/examples/hnsw.rs b/rust/lance/examples/hnsw.rs index ff5572a85f..cf00156009 100644 --- a/rust/lance/examples/hnsw.rs +++ b/rust/lance/examples/hnsw.rs @@ -111,7 +111,7 @@ async fn main() { .search(q, k, args.ef, None) .unwrap() .iter() - .map(|(i, _)| *i) + .map(|node| node.id) .collect(); let search_time = now.elapsed().as_micros(); println!( diff --git a/rust/lance/src/index/vector/hnsw.rs b/rust/lance/src/index/vector/hnsw.rs index da8384448e..b2bfbc85ff 100644 --- a/rust/lance/src/index/vector/hnsw.rs +++ b/rust/lance/src/index/vector/hnsw.rs @@ -180,8 +180,10 @@ impl VectorIndex for HNSWIndex { bitmap, )?; - let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.0 as usize])); - let distances = Arc::new(Float32Array::from_iter_values(results.iter().map(|x| x.1))); + let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| row_ids[x.id as usize])); + let distances = Arc::new(Float32Array::from_iter_values( + results.iter().map(|x| x.dist.0), + )); let schema = Arc::new(arrow_schema::Schema::new(vec![ arrow_schema::Field::new(DIST_COL, DataType::Float32, true),