diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index d103d983a9..74d437f1ef 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -19,17 +19,18 @@ arrow-schema.workspace = true arrow-select.workspace = true async-recursion.workspace = true async-trait.workspace = true -datafusion.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true datafusion-physical-expr.workspace = true datafusion-sql.workspace = true +datafusion.workspace = true futures.workspace = true half.workspace = true +itertools.workspace = true lance-arrow.workspace = true lance-core.workspace = true -lance-file.workspace = true lance-datafusion.workspace = true +lance-file.workspace = true lance-io.workspace = true lance-linalg.workspace = true lance-table.workspace = true diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index 955e7d8cb9..cdd0e07de3 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -26,7 +26,8 @@ pub(crate) mod builder; pub mod memory; pub(super) mod storage; -use storage::VectorStorage; +/// Vector storage to back a graph. +pub use storage::VectorStorage; pub(crate) const NEIGHBORS_COL: &str = "__neighbors"; @@ -87,6 +88,36 @@ impl From for f32 { } } +#[derive(Debug, Eq, PartialEq, Clone)] +pub(crate) struct OrderedNode { + pub id: u32, + pub dist: OrderedFloat, +} + +impl PartialOrd for OrderedNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.dist.cmp(&other.dist)) + } +} + +impl Ord for OrderedNode { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.dist.cmp(&other.dist) + } +} + +impl From<(OrderedFloat, u32)> for OrderedNode { + fn from((dist, id): (OrderedFloat, u32)) -> Self { + Self { id, dist } + } +} + +impl From for (OrderedFloat, u32) { + fn from(node: OrderedNode) -> Self { + (node.dist, node.id) + } +} + /// Distance calculator. 
/// /// This trait is used to calculate a query vector to a stream of vector IDs. @@ -113,7 +144,7 @@ pub trait Graph { } /// Get the neighbors of a graph node, identifyied by the index. - fn neighbors(&self, key: u32) -> Option + '_>>; + fn neighbors(&self, key: u32) -> Option + '_>>; /// Access to underline storage fn storage(&self) -> Arc; @@ -163,7 +194,7 @@ pub(super) fn beam_search( location: location!(), })?; - for &neighbor in neighbors { + for neighbor in neighbors { if visited.contains(&neighbor) { continue; } diff --git a/rust/lance-index/src/vector/graph/builder.rs b/rust/lance-index/src/vector/graph/builder.rs index 521dec2358..67c21127a4 100644 --- a/rust/lance-index/src/vector/graph/builder.rs +++ b/rust/lance-index/src/vector/graph/builder.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::{BinaryHeap, HashMap}; use std::sync::Arc; use lance_core::{Error, Result}; use snafu::{location, Location}; +use super::OrderedNode; use super::{memory::InMemoryVectorStorage, Graph, GraphNode, OrderedFloat}; use crate::vector::graph::storage::VectorStorage; @@ -26,44 +27,48 @@ use crate::vector::graph::storage::VectorStorage; pub struct GraphBuilderNode { /// Node ID pub(crate) id: u32, - /// Neighbors, sorted by the distance. - pub(crate) neighbors: BTreeMap, - /// Pointer to the next level of graph, or acts as the idx - pub pointer: u32, + /// Neighbors, sorted by the distance. + pub(crate) neighbors: BinaryHeap, } impl GraphBuilderNode { fn new(id: u32) -> Self { Self { id, - neighbors: BTreeMap::new(), - pointer: 0, + neighbors: BinaryHeap::new(), } } + fn add_neighbor(&mut self, distance: f32, id: u32) { + self.neighbors.push(OrderedNode { + dist: OrderedFloat(distance), + id, + }); + } + /// Prune the node and only keep `max_edges` edges. /// /// Returns the ids of pruned neighbors. 
- fn prune(&mut self, max_edges: usize) -> Vec { - if self.neighbors.len() <= max_edges { - return vec![]; - } - - let mut pruned = Vec::with_capacity(self.neighbors.len() - max_edges); + fn prune(&mut self, max_edges: usize) { while self.neighbors.len() > max_edges { - let (_, node) = self.neighbors.pop_last().unwrap(); - pruned.push(node) + self.neighbors.pop(); } - pruned } } impl From<&GraphBuilderNode> for GraphNode { fn from(node: &GraphBuilderNode) -> Self { + let neighbors = node + .neighbors + .clone() + .into_sorted_vec() + .into_iter() + .map(|n| n.id) + .collect::>(); Self { id: node.id, - neighbors: node.neighbors.values().copied().collect(), + neighbors, } } } @@ -73,7 +78,7 @@ impl From<&GraphBuilderNode> for GraphNode { /// [GraphBuilder] is used to build a graph in memory. /// pub struct GraphBuilder { - pub(crate) nodes: BTreeMap, + pub(crate) nodes: HashMap, /// Storage for vectors. vectors: Arc, @@ -84,9 +89,15 @@ impl Graph for GraphBuilder { self.nodes.len() } - fn neighbors(&self, key: u32) -> Option + '_>> { + fn neighbors(&self, key: u32) -> Option + '_>> { let node = self.nodes.get(&key)?; - Some(Box::new(node.neighbors.values())) + Some(Box::new( + node.neighbors + .clone() + .into_sorted_vec() + .into_iter() + .map(|n| n.id), + )) } fn storage(&self) -> Arc { @@ -98,7 +109,7 @@ impl GraphBuilder { /// Build from a [VectorStorage]. pub fn new(vectors: Arc) -> Self { Self { - nodes: BTreeMap::new(), + nodes: HashMap::new(), vectors, } } @@ -110,14 +121,14 @@ impl GraphBuilder { /// Connect from one node to another. 
pub fn connect(&mut self, from: u32, to: u32) -> Result<()> { - let distance: OrderedFloat = self.vectors.distance_between(from, to).into(); + let distance = self.vectors.distance_between(from, to); { let from_node = self.nodes.get_mut(&from).ok_or_else(|| Error::Index { message: format!("Node {} not found", from), location: location!(), })?; - from_node.neighbors.insert(distance, to); + from_node.add_neighbor(distance, to) } { @@ -125,7 +136,7 @@ impl GraphBuilder { message: format!("Node {} not found", to), location: location!(), })?; - to_node.neighbors.insert(distance, from); + to_node.add_neighbor(distance, from); } Ok(()) } @@ -149,7 +160,7 @@ impl GraphBuilder { let edges = node.neighbors.len(); total_edges += edges; max_edges = max_edges.max(edges); - total_distance += node.neighbors.keys().map(|d| d.0).sum::(); + total_distance += node.neighbors.iter().map(|n| n.dist.0).sum::(); } GraphBuilderStats { @@ -189,7 +200,17 @@ mod tests { builder.connect(0, 1).unwrap(); assert_eq!(builder.len(), 2); - assert_eq!(builder.neighbors(0).unwrap().collect::>(), vec![&1]); - assert_eq!(builder.neighbors(1).unwrap().collect::>(), vec![&0]); + assert_eq!(builder.neighbors(0).unwrap().collect::>(), vec![1]); + assert_eq!(builder.neighbors(1).unwrap().collect::>(), vec![0]); + + builder.insert(4); + builder.connect(0, 4).unwrap(); + assert_eq!(builder.len(), 3); + + assert_eq!( + builder.neighbors(0).unwrap().collect::>(), + vec![1, 4] + ); + assert_eq!(builder.neighbors(1).unwrap().collect::>(), vec![0]); } } diff --git a/rust/lance-index/src/vector/graph/storage.rs b/rust/lance-index/src/vector/graph/storage.rs index 8062d45207..8a4e63765a 100644 --- a/rust/lance-index/src/vector/graph/storage.rs +++ b/rust/lance-index/src/vector/graph/storage.rs @@ -28,6 +28,11 @@ pub trait DistCalculator { pub trait VectorStorage: Send + Sync { fn len(&self) -> usize; + /// Returns true if this storage is empty, i.e. `len()` is zero. 
+ fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Return the metric type of the vectors. fn metric_type(&self) -> MetricType; diff --git a/rust/lance-index/src/vector/hnsw.rs b/rust/lance-index/src/vector/hnsw.rs index 1539f1260a..e76ade8000 100644 --- a/rust/lance-index/src/vector/hnsw.rs +++ b/rust/lance-index/src/vector/hnsw.rs @@ -17,17 +17,19 @@ //! Hierarchical Navigable Small World (HNSW). //! -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet}; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; +use arrow::datatypes::UInt32Type; use arrow_array::{ builder::{ListBuilder, UInt32Builder}, cast::AsArray, ListArray, RecordBatch, UInt32Array, }; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use itertools::Itertools; use lance_core::{Error, Result}; use lance_file::{reader::FileReader, writer::FileWriter}; use lance_linalg::distance::MetricType; @@ -36,16 +38,15 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use snafu::{location, Location}; -use self::storage::HnswRemappingStorage; - use super::graph::{ - builder::GraphBuilder, storage::VectorStorage, Graph, OrderedFloat, NEIGHBORS_COL, + builder::GraphBuilder, storage::VectorStorage, Graph, OrderedFloat, OrderedNode, NEIGHBORS_COL, NEIGHBORS_FIELD, }; use crate::vector::graph::beam_search; + pub mod builder; + pub use builder::HNSWBuilder; -mod storage; const HNSW_TYPE: &str = "HNSW"; const VECTOR_ID_COL: &str = "__vector_id"; @@ -63,6 +64,11 @@ lazy_static::lazy_static! { /// One level of the HNSW graph. /// struct HnswLevel { + /// Vector ID to the node index in `nodes`. + /// The node on different layer share the same Vector ID, which is the index + /// in the [VectorStorage]. + id_to_node: HashMap, + /// All the nodes in this level. // TODO: we just load the whole level into memory without pagation. nodes: RecordBatch, @@ -76,11 +82,8 @@ struct HnswLevel { /// The values of the neighbors array. 
neighbors_values: Arc, - /// Id of the vector in the [VectorStorage]. - vector_ids: Arc, - /// Vector storage of the graph. - vectors: Arc, + vectors: Arc, } impl HnswLevel { @@ -103,47 +106,41 @@ impl HnswLevel { .clone() .into(); let values: Arc = neighbors.values().as_primitive().clone().into(); - let vector_ids: Arc = nodes + let id_to_node = nodes .column_by_name(VECTOR_ID_COL) .unwrap() - .as_primitive() - .clone() - .into(); - let vectors = Arc::new(storage::HnswRemappingStorage::new( - vectors, - vector_ids.clone(), - )); + .as_primitive::() + .values() + .iter() + .enumerate() + .map(|(idx, &vec_id)| (vec_id, idx)) + .collect::>(); Self { nodes, neighbors, neighbors_values: values, - vector_ids, + id_to_node, vectors, } } fn from_builder(builder: &GraphBuilder, vectors: Arc) -> Result { let mut neighbours_builder = ListBuilder::new(UInt32Builder::new()); - let mut pointers_builder = UInt32Builder::new(); let mut vector_id_builder = UInt32Builder::new(); - for (_, node) in builder.nodes.iter() { - neighbours_builder.append_value(node.neighbors.values().map(|&n| Some(n))); - pointers_builder.append_value(node.pointer); + for &id in builder.nodes.keys().sorted() { + let node = builder.nodes.get(&id).unwrap(); + assert_eq!(node.id, id); + neighbours_builder.append_value(node.neighbors.clone().iter().map(|n| Some(n.id))); vector_id_builder.append_value(node.id); } - let schema = Schema::new(vec![ - NEIGHBORS_FIELD.clone(), - VECTOR_ID_FIELD.clone(), - POINTER_FIELD.clone(), - ]); + let schema = Schema::new(vec![NEIGHBORS_FIELD.clone(), VECTOR_ID_FIELD.clone()]); let batch = RecordBatch::try_new( schema.into(), vec![ Arc::new(neighbours_builder.finish()), Arc::new(vector_id_builder.finish()), - Arc::new(pointers_builder.finish()), ], )?; @@ -155,17 +152,12 @@ impl HnswLevel { } /// Range of neighbors for the given node, specified by its index. 
- fn neighbors_range(&self, idx: u32) -> Range { - let start = self.neighbors.value_offsets()[idx as usize] as usize; - let end = start + self.neighbors.value_length(idx as usize) as usize; + fn neighbors_range(&self, id: u32) -> Range { + let idx = self.id_to_node[&id]; + let start = self.neighbors.value_offsets()[idx] as usize; + let end = start + self.neighbors.value_length(idx) as usize; start..end } - - fn pointers(&self, ids: &[u32]) -> Vec { - ids.iter() - .map(|&id| self.vector_ids.value(id as usize)) - .collect() - } } impl Graph for HnswLevel { @@ -173,9 +165,11 @@ impl Graph for HnswLevel { self.nodes.num_rows() } - fn neighbors(&self, key: u32) -> Option + '_>> { + fn neighbors(&self, key: u32) -> Option + '_>> { let range = self.neighbors_range(key); - Some(Box::new(self.neighbors_values.values()[range].iter())) + Some(Box::new( + self.neighbors_values.values()[range].iter().copied(), + )) } fn storage(&self) -> Arc { @@ -224,8 +218,8 @@ impl HNSW { /// /// Parameters /// ---------- - /// reader : &FileReader - /// vector_storage : Arc + /// - *reader*: the file reader to read the graph from. + /// - *vector_storage*: A preloaded [VectorStorage] storage. pub async fn load(reader: &FileReader, vector_storage: Arc) -> Result { let schema = reader.schema(); let mt = if let Some(index_metadata) = schema.metadata.get("lance:index") { @@ -279,6 +273,7 @@ impl HNSW { } } + /// The Arrow schema of the graph. pub fn schema(&self) -> SchemaRef { self.levels[0].schema() } @@ -287,30 +282,30 @@ impl HNSW { /// /// Parameters /// ---------- - /// query : &[f32] - /// The query vector. - /// k : usize - /// The number of nearest neighbors to search for. - /// ef : usize - /// The size of dynamic candidate list + /// - *query* : the query vector. + /// - *k* : the number of nearest neighbors to search for. + /// - *ef* : the size of dynamic candidate list. + /// + /// Returns + /// ------- + /// A list of `(id_in_graph, distance)` pairs. 
Or Error if the search failed. pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Result> { let mut ep = vec![self.entry_point]; let num_layers = self.levels.len(); for level in self.levels.iter().rev().take(num_layers - 1) { let candidates = beam_search(level, &ep, query, 1)?; ep = if self.use_select_heuristic { - select_neighbors_heuristic(level, query, &candidates, 1, false, true) + select_neighbors_heuristic(level, query, &candidates, 1, false) .map(|(_, id)| id) .collect() } else { select_neighbors(&candidates, 1).map(|(_, id)| id).collect() }; - ep = level.pointers(&ep); } let candidates = beam_search(&self.levels[0], &ep, query, ef)?; if self.use_select_heuristic { Ok( - select_neighbors_heuristic(&self.levels[0], query, &candidates, k, false, true) + select_neighbors_heuristic(&self.levels[0], query, &candidates, k, false) .map(|(d, u)| (u, d.into())) .collect(), ) @@ -362,57 +357,41 @@ fn select_neighbors( } /// Algorithm 4 in the HNSW paper. +/// +/// +/// Modifications to the original algorithm: +/// 1. `keepPrunedConnections` is not used; a heap is used to capture the nearest neighbors instead. 
fn select_neighbors_heuristic( graph: &dyn Graph, query: &[f32], orderd_candidates: &BTreeMap, k: usize, - extended_candidates: bool, - keep_pruned_connections: bool, + extend_candidates: bool, ) -> impl Iterator { - let mut results = BTreeMap::new(); - let mut w = orderd_candidates.values().cloned().collect::>(); - // W in paper - let mut candidates = orderd_candidates.clone(); - assert_eq!(w.len(), candidates.len()); + let mut heap: BinaryHeap = BinaryHeap::from_iter( + orderd_candidates + .iter() + .map(|(&d, &u)| OrderedNode { id: u, dist: d }), + ); - if extended_candidates { + if extend_candidates { let dist_calc = graph.storage().dist_calculator(query); + let mut visited = HashSet::with_capacity(orderd_candidates.len() * 64); + visited.extend(orderd_candidates.values()); orderd_candidates.iter().for_each(|(_, &u)| { if let Some(neighbors) = graph.neighbors(u) { neighbors.for_each(|n| { - if !w.contains(n) { - candidates.insert(dist_calc.distance(&[*n])[0].into(), *n); + if !visited.contains(&n) { + let d: OrderedFloat = dist_calc.distance(&[n])[0].into(); + heap.push((d, n).into()); } - w.insert(*n); + visited.insert(n); }); } }); } - let mut discarded = BTreeMap::::new(); - while !candidates.is_empty() && results.len() < k { - let (d, u) = candidates.pop_first().unwrap(); - if let Some((&key, &value)) = results.last_key_value() { - if key > d { - candidates.insert(key, value); - } else { - discarded.insert(d, u); - } - } else { - results.insert(d, u); - } - } - if keep_pruned_connections && results.len() < k { - results.extend( - discarded - .iter() - .take(k - results.len()) - .map(|(&d, &n)| (d, n)), - ); - } - - results.into_iter().take(k) + heap.into_sorted_vec().into_iter().take(k).map(|n| n.into()) } #[cfg(test)] @@ -438,7 +417,7 @@ mod tests { vec![ (OrderedFloat(1.0), 1), (OrderedFloat(2.0), 2), - (OrderedFloat(3.0), 3) + (OrderedFloat(3.0), 3), ] ); @@ -478,9 +457,9 @@ mod tests { }); hnsw.levels.iter().for_each(|layer| { - for i in 0..layer.len() 
{ + for &i in layer.id_to_node.keys() { // If the node exist on this layer, check its out-degree. - if let Some(neighbors) = layer.neighbors(i as u32) { + if let Some(neighbors) = layer.neighbors(i) { let cnt = neighbors.count(); assert!(cnt <= MAX_EDGES, "actual {}, max_edges: {}", cnt, MAX_EDGES); } diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index 86be889698..0e4a028de5 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Builder of Hnsw Graph. + use std::cmp::min; -use std::collections::HashMap; use std::sync::Arc; use lance_core::Result; +use log::{info, log_enabled, Level::Info}; use rand::{thread_rng, Rng}; use super::super::graph::{beam_search, memory::InMemoryVectorStorage}; @@ -29,6 +31,9 @@ use crate::vector::hnsw::HnswLevel; /// /// Currently, the HNSW graph is fully built in memory. /// +/// During the build, the graph is built layer by layer. +/// +/// Each node in the graph has a global ID which is the index on the base layer. pub struct HNSWBuilder { /// max level of max_level: u16, @@ -46,6 +51,10 @@ pub struct HNSWBuilder { entry_point: u32, + extend_candidates: bool, + + log_base: f32, + use_select_heuristic: bool, } @@ -54,22 +63,26 @@ impl HNSWBuilder { pub fn new(vectors: Arc) -> Self { Self { max_level: 8, - m_max: 32, + m_max: 64, ef_construction: 100, vectors, levels: vec![], entry_point: 0, + extend_candidates: false, + log_base: 10.0, use_select_heuristic: true, } } /// The maximum level of the graph. + /// The default value is `8`. pub fn max_level(mut self, max_level: u16) -> Self { self.max_level = max_level; self } /// The maximum number of connections for each node per layer. + /// The default value is `64`. 
pub fn max_num_edges(mut self, m_max: usize) -> Self { self.m_max = m_max; self @@ -77,11 +90,26 @@ impl HNSWBuilder { /// Number of candidates to be considered when searching for the nearest neighbors /// during the construction of the graph. + /// + /// The default value is `100`. pub fn ef_construction(mut self, ef_construction: usize) -> Self { self.ef_construction = ef_construction; self } + /// Whether to extend the candidate set with the candidates' neighbors during heuristic neighbor selection. + /// + /// The default value is `false`. + /// + /// See `extendCandidates` parameter in the paper (Algorithm 4). + pub fn extend_candidates(mut self, flag: bool) -> Self { + self.extend_candidates = flag; + self + } + + /// Use select heuristic when searching for the nearest neighbors. + /// + /// See algorithm 4 in HNSW paper. pub fn use_select_heuristic(mut self, flag: bool) -> Self { self.use_select_heuristic = flag; self @@ -92,18 +120,21 @@ impl HNSWBuilder { self.levels[0].len() } - #[inline] - fn m_l(&self) -> f32 { - 1.0 / (self.len() as f32).ln() - } - - /// new node's level + /// New node's level /// /// See paper `Algorithm 1` fn random_level(&self) -> u16 { let mut rng = thread_rng(); - let r = rng.gen::(); - min((-r.ln() * self.m_l()).floor() as u16, self.max_level) + // This is different to the paper. + // We use log10 instead of log(e), so each layer has about 1/10 of its bottom layer. + let m = self.vectors.len(); + min( + (m as f32).log(self.log_base).ceil() as u16 + - (rng.gen::() * self.vectors.len() as f32) + .log(self.log_base) + .ceil() as u16, + self.max_level, + ) } /// Insert one node. 
@@ -130,22 +161,21 @@ impl HNSWBuilder { for cur_level in self.levels.iter().rev().take(levels_to_search) { let candidates = beam_search(cur_level, &ep, vector, self.ef_construction)?; ep = if self.use_select_heuristic { - select_neighbors_heuristic(cur_level, query, &candidates, 1, true, true) - .map(|(_, id)| cur_level.nodes[&id].pointer) + select_neighbors_heuristic(cur_level, query, &candidates, 1, self.extend_candidates) + .map(|(_, id)| id) .collect() } else { - select_neighbors(&candidates, 1) - .map(|(_, id)| cur_level.nodes[&id].pointer) - .collect() + select_neighbors(&candidates, 1).map(|(_, id)| id).collect() }; } - let m = self.levels[0].nodes.len(); + let m = self.len(); for cur_level in self.levels.iter_mut().rev().skip(levels_to_search) { cur_level.insert(node); let candidates = beam_search(cur_level, &ep, vector, self.ef_construction)?; let neighbours: Vec<_> = if self.use_select_heuristic { - select_neighbors_heuristic(cur_level, query, &candidates, m, true, true).collect() + select_neighbors_heuristic(cur_level, query, &candidates, m, self.extend_candidates) + .collect() } else { select_neighbors(&candidates, m).collect() }; @@ -157,10 +187,7 @@ impl HNSWBuilder { cur_level.prune(nb, self.m_max)?; } cur_level.prune(node, self.m_max)?; - ep = candidates - .values() - .map(|id| cur_level.nodes[id].pointer) - .collect(); + ep = candidates.values().copied().collect(); } if level > self.levels.len() as u16 { @@ -175,7 +202,7 @@ impl HNSWBuilder { self.build_with(self.vectors.clone()) } - /// Build the graph, with the provided [`VectorStorage`] as backing storage for HNSW graph. + /// Build the graph, with the provided [VectorStorage] as backing storage for HNSW graph. 
pub fn build_with(&mut self, storage: Arc) -> Result { log::info!( "Building HNSW graph: metric_type={}, max_levels={}, m_max={}, ef_construction={}", @@ -194,7 +221,11 @@ impl HNSWBuilder { self.insert(i as u32)?; } - remapping_levels(&mut self.levels); + if log_enabled!(Info) { + for (i, level) in self.levels.iter().enumerate() { + info!("HNSW level {}: {:#?}", i, level.stats()); + } + } let graphs = self .levels @@ -209,89 +240,3 @@ impl HNSWBuilder { )) } } - -/// Because each level is stored as a separate continous RecordBatch. We need to remap the pointers -/// to the nodes in the previous level to the index in the current RecordBatch. -fn remapping_levels(levels: &mut [GraphBuilder]) { - for i in 1..levels.len() { - let prev_level = &levels[i - 1]; - let mapping = prev_level - .nodes - .keys() - .enumerate() - .map(|(i, &id)| (id, i as u32)) - .collect::>(); - let cur_level = &mut levels[i]; - let current_mapping = cur_level - .nodes - .keys() - .enumerate() - .map(|(idx, &id)| (id, idx as u32)) - .collect::>(); - for node in cur_level.nodes.values_mut() { - node.pointer = *mapping.get(&node.id).expect("Expect the pointer exists"); - - // Remapping the neighbors within this level of graph. 
- node.neighbors = node - .neighbors - .iter() - .map(|(d, n)| { - ( - *d, - *current_mapping.get(n).expect("Expect the pointer exists"), - ) - }) - .collect(); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow_array::types::Float32Type; - use lance_linalg::{distance::MetricType, matrix::MatrixView}; - use lance_testing::datagen::generate_random_array; - - #[test] - fn test_remapping_levels() { - let data = generate_random_array(8 * 100); - let mat = MatrixView::::new(Arc::new(data), 8); - let storage = Arc::new(InMemoryVectorStorage::new(mat.into(), MetricType::L2)); - let mut level0 = GraphBuilder::new(storage.clone()); - for i in 0..100 { - level0.insert(i as u32); - } - let mut level1 = GraphBuilder::new(storage.clone()); - for i in [0, 5, 10, 15, 20, 30, 40, 50] { - level1.insert(i as u32); - } - let mut level2 = GraphBuilder::new(storage.clone()); - for i in [0, 10, 20, 50] { - level2.insert(i as u32); - } - let mut levels = [level0, level1, level2]; - remapping_levels(&mut levels); - assert_eq!( - levels[1] - .nodes - .values() - .map(|n| n.pointer) - .collect::>(), - vec![0, 5, 10, 15, 20, 30, 40, 50] - ); - assert_eq!( - levels[2] - .nodes - .values() - .map(|n| n.pointer) - .collect::>(), - vec![0, 2, 4, 7] - ); - - println!("{:?}", levels[2].nodes); - } -} diff --git a/rust/lance-index/src/vector/hnsw/storage.rs b/rust/lance-index/src/vector/hnsw/storage.rs deleted file mode 100644 index fddd8d6c28..0000000000 --- a/rust/lance-index/src/vector/hnsw/storage.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2024 Lance Developers. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use arrow_array::UInt32Array; -use lance_linalg::distance::MetricType; - -use crate::vector::graph::storage::{DistCalculator, VectorStorage}; - -/// Remapping level id to vector id. -/// -/// Each node in one level of HNSW has an vector id. -/// Which is not the same as the node id in this level of graph -pub(super) struct HnswRemappingStorage { - raw_vectors: Arc, - - vector_ids: Arc, -} - -impl HnswRemappingStorage { - pub fn new(raw_vectors: Arc, vector_ids: Arc) -> Self { - Self { - raw_vectors, - vector_ids, - } - } -} - -impl VectorStorage for HnswRemappingStorage { - fn len(&self) -> usize { - self.raw_vectors.len() - } - - fn metric_type(&self) -> MetricType { - self.raw_vectors.metric_type() - } - - fn dist_calculator(&self, query: &[f32]) -> Box { - let calc = self.raw_vectors.dist_calculator(query); - Box::new(HnswDistCalculator { - raw_calculator: calc, - vector_ids: self.vector_ids.clone(), - }) - } -} - -struct HnswDistCalculator { - raw_calculator: Box, - vector_ids: Arc, -} - -impl DistCalculator for HnswDistCalculator { - fn distance(&self, vector_ids: &[u32]) -> Vec { - let vector_ids = vector_ids - .iter() - .map(|&i| self.vector_ids.value(i as usize)) - .collect::>(); - self.raw_calculator.distance(&vector_ids) - } -} diff --git a/rust/lance/examples/hnsw.rs b/rust/lance/examples/hnsw.rs index c0412d1f92..ff8681f5e9 100644 --- a/rust/lance/examples/hnsw.rs +++ b/rust/lance/examples/hnsw.rs @@ -40,6 +40,13 @@ struct Args { #[arg(long, default_value = "100")] ef: usize, + + /// Max number of edges of each 
node. + #[arg(long, default_value = "64")] + max_edges: usize, + + #[arg(long, default_value = "7")] + max_level: u16, } fn ground_truth(mat: &MatrixView, query: &[f32], k: usize) -> HashSet { @@ -83,32 +90,31 @@ async fn main() { let k = 10; let gt = ground_truth(&mat, q, k); - for level in [4, 8, 16, 32] { - for ef_construction in [50, 100, 200, 400] { - let now = std::time::Instant::now(); - let hnsw = HNSWBuilder::new(vector_store.clone()) - .max_level(level) - .ef_construction(ef_construction) - .build() - .unwrap(); - let construct_time = now.elapsed().as_secs_f32(); - let now = std::time::Instant::now(); - let results: HashSet = hnsw - .search(q, k, args.ef) - .unwrap() - .iter() - .map(|(i, _)| *i) - .collect(); - let search_time = now.elapsed().as_micros(); - println!( - "level={}, ef_construct={}, ef={} recall={}: construct={:.3}s search={:.3} us", - level, - ef_construction, - args.ef, - results.intersection(>).count() as f32 / k as f32, - construct_time, - search_time - ); - } + for ef_construction in [50, 100, 200, 400] { + let now = std::time::Instant::now(); + let hnsw = HNSWBuilder::new(vector_store.clone()) + .max_level(args.max_level) + .max_num_edges(args.max_edges) + .ef_construction(ef_construction) + .build() + .unwrap(); + let construct_time = now.elapsed().as_secs_f32(); + let now = std::time::Instant::now(); + let results: HashSet = hnsw + .search(q, k, args.ef) + .unwrap() + .iter() + .map(|(i, _)| *i) + .collect(); + let search_time = now.elapsed().as_micros(); + println!( + "level={}, ef_construct={}, ef={} recall={}: construct={:.3}s search={:.3} us", + args.max_level, + ef_construction, + args.ef, + results.intersection(>).count() as f32 / k as f32, + construct_time, + search_time + ); } }