Skip to content

Commit

Permalink
feat: implement heuristic select neighbor algorithm (#1991)
Browse files Browse the repository at this point in the history
Implement the `SELECT-NEIGHBORS-HEURISTIC` algorithm from the HNSW
paper.

---------

Co-authored-by: Weston Pace <[email protected]>
  • Loading branch information
eddyxu and westonpace authored Feb 24, 2024
1 parent eef90e0 commit 5c1f46c
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 24 deletions.
5 changes: 3 additions & 2 deletions rust/lance-index/examples/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,15 @@ async fn main() {
.await
.expect("Failed to open dataset");
println!("Dataset schema: {:#?}", dataset.schema());
let column = args.column.as_deref().unwrap_or("vector");
let batches = dataset
.scan()
.project(&[args.column.as_deref().unwrap_or("vector")])
.project(&[column])
.unwrap()
.try_into_stream()
.await
.unwrap()
.then(|batch| async move { batch.unwrap().column_by_name("openai").unwrap().clone() })
.then(|batch| async move { batch.unwrap().column_by_name(column).unwrap().clone() })
.collect::<Vec<_>>()
.await;
let arrs = batches.iter().map(|b| b.as_ref()).collect::<Vec<_>>();
Expand Down
29 changes: 29 additions & 0 deletions rust/lance-index/src/vector/graph/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,35 @@ impl GraphBuilder {
node.prune(max_edges);
Ok(())
}

/// Collect summary statistics about the graph built so far.
///
/// Walks every node once and reports the node count, the largest number of
/// neighbors held by a single node, the mean number of neighbors per node and
/// the mean of the neighbor distances.
///
/// Both means are reported as `0.0` (instead of `NaN`) when there is nothing
/// to average over, i.e. the graph has no nodes or no edges.
// NOTE(review): presumably only used for debugging / tests, hence the allow; confirm.
#[allow(dead_code)]
pub(crate) fn stats(&self) -> GraphBuilderStats {
    let num_nodes = self.nodes.len();
    let mut max_edges = 0_usize;
    let mut total_edges = 0_usize;
    let mut total_distance = 0.0_f32;

    for node in self.nodes.values() {
        let edges = node.neighbors.len();
        total_edges += edges;
        max_edges = max_edges.max(edges);
        // The neighbor map is keyed by distance; `.0` unwraps the raw `f32`.
        total_distance += node.neighbors.keys().map(|d| d.0).sum::<f32>();
    }

    // Guard the divisions: 0 / 0 in `f32` silently yields `NaN`.
    let mean_edges = if num_nodes == 0 {
        0.0
    } else {
        total_edges as f32 / num_nodes as f32
    };
    let mean_distance = if total_edges == 0 {
        0.0
    } else {
        total_distance / total_edges as f32
    };

    GraphBuilderStats {
        num_nodes,
        max_edges,
        mean_edges,
        mean_distance,
    }
}
}

/// Summary statistics of a graph under construction, as returned by
/// `GraphBuilder::stats`.
#[derive(Debug)]
pub struct GraphBuilderStats {
    /// Number of nodes in the graph.
    pub num_nodes: usize,
    /// Largest number of neighbors held by any single node.
    pub max_edges: usize,
    /// Total number of neighbor entries divided by the number of nodes.
    pub mean_edges: f32,
    /// Sum of the distance keys of all neighbor entries divided by the total
    /// number of neighbor entries.
    pub mean_distance: f32,
}

#[cfg(test)]
Expand Down
95 changes: 87 additions & 8 deletions rust/lance-index/src/vector/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//! Hierarchical Navigable Small World (HNSW).
//!
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
Expand Down Expand Up @@ -190,6 +190,8 @@ pub struct HNSW {
metric_type: MetricType,
/// Entry point of the graph.
entry_point: u32,
/// Whether to use the heuristic to select neighbors (Algorithm 4 or 3 in the paper).
use_select_heuristic: bool,
}

impl Debug for HNSW {
Expand Down Expand Up @@ -259,14 +261,21 @@ impl HNSW {
levels,
metric_type: mt,
entry_point: hnsw_metadata.entry_point,
use_select_heuristic: true,
})
}

fn from_builder(levels: Vec<HnswLevel>, entry_point: u32, metric_type: MetricType) -> Self {
fn from_builder(
levels: Vec<HnswLevel>,
entry_point: u32,
metric_type: MetricType,
use_select_heuristic: bool,
) -> Self {
Self {
levels,
metric_type,
entry_point,
use_select_heuristic,
}
}

Expand All @@ -289,13 +298,27 @@ impl HNSW {
let num_layers = self.levels.len();
for level in self.levels.iter().rev().take(num_layers - 1) {
let candidates = beam_search(level, &ep, query, 1)?;
ep = select_neighbors(&candidates, 1).map(|(_, id)| id).collect();
ep = if self.use_select_heuristic {
select_neighbors_heuristic(level, query, &candidates, 1, false, true)
.map(|(_, id)| id)
.collect()
} else {
select_neighbors(&candidates, 1).map(|(_, id)| id).collect()
};
ep = level.pointers(&ep);
}
let candidates = beam_search(&self.levels[0], &ep, query, ef)?;
Ok(select_neighbors(&candidates, k)
.map(|(d, u)| (u, d.into()))
.collect())
if self.use_select_heuristic {
Ok(
select_neighbors_heuristic(&self.levels[0], query, &candidates, k, false, true)
.map(|(d, u)| (u, d.into()))
.collect(),
)
} else {
Ok(select_neighbors(&candidates, k)
.map(|(d, u)| (u, d.into()))
.collect())
}
}

/// Write the HNSW graph to a Lance file.
Expand Down Expand Up @@ -338,6 +361,60 @@ fn select_neighbors(
orderd_candidates.iter().take(k).map(|(&d, &u)| (d, u))
}

/// Algorithm 4 (`SELECT-NEIGHBORS-HEURISTIC`) in the HNSW paper.
///
/// * `graph`: the graph level the candidates live in; only consulted when
///   `extended_candidates` is set.
/// * `query`: the query vector the candidate distances refer to.
/// * `orderd_candidates`: candidate node ids keyed by their distance to `query`.
/// * `k`: maximum number of neighbors to return.
/// * `extended_candidates`: also consider the neighbors of every candidate.
/// * `keep_pruned_connections`: top the result up to `k` entries with the
///   closest discarded candidates.
///
/// Returns at most `k` `(distance, id)` pairs in ascending distance order.
///
/// # Panics
///
/// Panics if the same node id appears more than once in `orderd_candidates`.
///
/// NOTE(review): the paper admits a candidate `e` only if it is closer to the
/// query than to every element already selected. That check is not
/// implemented here: only the closest candidate is admitted and all others are
/// discarded. With `keep_pruned_connections == true` the result is therefore
/// the `k` closest candidates; with `false` it is at most one element.
fn select_neighbors_heuristic(
    graph: &dyn Graph,
    query: &[f32],
    orderd_candidates: &BTreeMap<OrderedFloat, u32>,
    k: usize,
    extended_candidates: bool,
    keep_pruned_connections: bool,
) -> impl Iterator<Item = (OrderedFloat, u32)> {
    // Ids already present in the working queue, so that a node reached through
    // several candidates is only scored once.
    let mut seen = orderd_candidates.values().copied().collect::<HashSet<_>>();
    // W in paper
    let mut candidates = orderd_candidates.clone();
    assert_eq!(seen.len(), candidates.len());

    if extended_candidates {
        let dist_calc = graph.storage().dist_calculator(query);
        for &u in orderd_candidates.values() {
            if let Some(neighbors) = graph.neighbors(u) {
                for n in neighbors {
                    // `insert` returns `true` only for ids not seen before.
                    if seen.insert(*n) {
                        candidates.insert(dist_calc.distance(&[*n])[0].into(), *n);
                    }
                }
            }
        }
    }

    let mut results = BTreeMap::new();
    let mut discarded = BTreeMap::<OrderedFloat, u32>::new();
    while results.len() < k {
        // Candidates come out in ascending distance order.
        let Some((d, u)) = candidates.pop_first() else {
            break;
        };
        if results.is_empty() {
            results.insert(d, u);
        } else {
            // Every later candidate is farther away than the one selected, so
            // the former `key > d` branch could never run and was removed.
            discarded.insert(d, u);
        }
    }

    if keep_pruned_connections && results.len() < k {
        let missing = k - results.len();
        results.extend(discarded.into_iter().take(missing));
    }

    results.into_iter().take(k)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -404,7 +481,8 @@ mod tests {
for i in 0..layer.len() {
// If the node exist on this layer, check its out-degree.
if let Some(neighbors) = layer.neighbors(i as u32) {
assert!(neighbors.count() <= MAX_EDGES);
let cnt = neighbors.count();
assert!(cnt <= MAX_EDGES, "actual {}, max_edges: {}", cnt, MAX_EDGES);
}
}
});
Expand All @@ -424,7 +502,7 @@ mod tests {
#[test]
fn test_search() {
const DIM: usize = 32;
const TOTAL: usize = 2048;
const TOTAL: usize = 1024;
const MAX_EDGES: usize = 32;
const K: usize = 10;

Expand All @@ -436,6 +514,7 @@ mod tests {
let hnsw = HNSWBuilder::new(vectors.clone())
.max_num_edges(MAX_EDGES)
.ef_construction(100)
.max_level(4)
.build()
.unwrap();

Expand Down
59 changes: 45 additions & 14 deletions rust/lance-index/src/vector/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use lance_core::Result;
use rand::{thread_rng, Rng};

use super::super::graph::{beam_search, memory::InMemoryVectorStorage};
use super::{select_neighbors, HNSW};
use super::{select_neighbors, select_neighbors_heuristic, HNSW};
use crate::vector::graph::Graph;
use crate::vector::graph::{builder::GraphBuilder, storage::VectorStorage};
use crate::vector::hnsw::HnswLevel;

Expand All @@ -32,9 +33,6 @@ pub struct HNSWBuilder {
/// Max level of the graph.
max_level: u16,

/// M_l parameter in the paper.
m_level_decay: f32,

/// Max number of connections for each element per layer.
m_max: usize,

Expand All @@ -47,6 +45,8 @@ pub struct HNSWBuilder {
levels: Vec<GraphBuilder>,

entry_point: u32,

use_select_heuristic: bool,
}

impl HNSWBuilder {
Expand All @@ -59,14 +59,13 @@ impl HNSWBuilder {
vectors,
levels: vec![],
entry_point: 0,
m_level_decay: 1.0 / 8_f32.ln(),
use_select_heuristic: true,
}
}

/// The maximum level of the graph.
pub fn max_level(mut self, max_level: u16) -> Self {
self.max_level = max_level;
self.m_level_decay = 1.0 / (max_level as f32).ln();
self
}

Expand All @@ -83,16 +82,28 @@ impl HNSWBuilder {
self
}

/// Choose how neighbors are selected while building the graph.
///
/// When `flag` is `true` the builder uses `select_neighbors_heuristic`;
/// otherwise it uses `select_neighbors`.
pub fn use_select_heuristic(self, flag: bool) -> Self {
    let mut builder = self;
    builder.use_select_heuristic = flag;
    builder
}

/// Length of the first level (`levels[0]`).
///
/// Panics if no level has been created yet.
#[inline]
fn len(&self) -> usize {
    let bottom_level = &self.levels[0];
    bottom_level.len()
}

/// Scaling factor applied in `random_level`: the reciprocal of the natural
/// logarithm of `self.len()`.
#[inline]
fn m_l(&self) -> f32 {
    let node_count = self.len() as f32;
    node_count.ln().recip()
}

/// new node's level
///
/// See paper `Algorithm 1`
fn random_level(&self) -> u16 {
let mut rng = thread_rng();
let r = rng.gen::<f32>();
min(
(-r.ln() * self.m_level_decay).floor() as u16,
self.max_level,
)
min((-r.ln() * self.m_l()).floor() as u16, self.max_level)
}

/// Insert one node.
Expand All @@ -106,6 +117,7 @@ impl HNSWBuilder {
0
};
let mut ep = vec![self.entry_point];
let query = self.vectors.vector(self.entry_point);

//
// Search for entry point in paper.
Expand All @@ -117,20 +129,38 @@ impl HNSWBuilder {
// ```
for cur_level in self.levels.iter().rev().take(levels_to_search) {
let candidates = beam_search(cur_level, &ep, vector, self.ef_construction)?;
let neighbours = select_neighbors(&candidates, 1);
ep = neighbours.map(|(_, id)| id).collect();
ep = if self.use_select_heuristic {
select_neighbors_heuristic(cur_level, query, &candidates, 1, true, true)
.map(|(_, id)| cur_level.nodes[&id].pointer)
.collect()
} else {
select_neighbors(&candidates, 1)
.map(|(_, id)| cur_level.nodes[&id].pointer)
.collect()
};
}

let m = self.levels[0].nodes.len();
for cur_level in self.levels.iter_mut().rev().skip(levels_to_search) {
cur_level.insert(node);
let candidates = beam_search(cur_level, &ep, vector, self.ef_construction)?;
let neighbours = select_neighbors(&candidates, self.m_max).collect::<Vec<_>>();
let neighbours: Vec<_> = if self.use_select_heuristic {
select_neighbors_heuristic(cur_level, query, &candidates, m, true, true).collect()
} else {
select_neighbors(&candidates, m).collect()
};

for (_, nb) in neighbours.iter() {
cur_level.connect(node, *nb)?;
}
for (_, nb) in neighbours {
cur_level.prune(nb, self.m_max)?;
}
ep = candidates.values().copied().collect::<Vec<_>>();
cur_level.prune(node, self.m_max)?;
ep = candidates
.values()
.map(|id| cur_level.nodes[id].pointer)
.collect();
}

if level > self.levels.len() as u16 {
Expand Down Expand Up @@ -175,6 +205,7 @@ impl HNSWBuilder {
graphs,
self.entry_point,
self.vectors.metric_type(),
self.use_select_heuristic,
))
}
}
Expand Down

0 comments on commit 5c1f46c

Please sign in to comment.