Skip to content

Commit

Permalink
perf: impl heuristic pruning for HNSW (#2171)
Browse files Browse the repository at this point in the history
improve the recall from 90% to 99%,
according to the HNSW paper, this also improve performance, especially
for highly clustered dataset

---------

Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Apr 10, 2024
1 parent d189ca2 commit f5980e9
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 94 deletions.
2 changes: 1 addition & 1 deletion rust/lance-index/benches/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn bench_hnsw(c: &mut Criterion) {
.search(query, K, 300, None)
.unwrap()
.iter()
.map(|(i, _)| *i)
.map(|node| node.id)
.collect();

assert_eq!(uids.len(), K);
Expand Down
36 changes: 21 additions & 15 deletions rust/lance-index/src/vector/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<I> From<I> for GraphNode<I> {
/// A wrapper for f32 to make it ordered, so that we can put it into
/// a BTree or Heap
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) struct OrderedFloat(pub f32);
pub struct OrderedFloat(pub f32);

impl PartialOrd for OrderedFloat {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Expand Down Expand Up @@ -92,11 +92,17 @@ impl From<OrderedFloat> for f32 {
}

#[derive(Debug, Eq, PartialEq, Clone)]
pub(crate) struct OrderedNode {
pub struct OrderedNode {
pub id: u32,
pub dist: OrderedFloat,
}

impl OrderedNode {
pub fn new(id: u32, dist: OrderedFloat) -> Self {
Self { id, dist }
}
}

impl PartialOrd for OrderedNode {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.dist.cmp(&other.dist))
Expand Down Expand Up @@ -181,37 +187,37 @@ pub(super) fn beam_search(
k: usize,
dist_calc: Option<Arc<dyn DistCalculator>>,
bitset: Option<&roaring::bitmap::RoaringBitmap>,
) -> Result<BinaryHeap<(OrderedFloat, u32)>> {
) -> Result<BinaryHeap<OrderedNode>> {
let mut visited: HashSet<_> = start.iter().copied().collect();
let dist_calc = dist_calc.unwrap_or_else(|| graph.storage().dist_calculator(query).into());
let mut candidates: BinaryHeap<Reverse<(OrderedFloat, u32)>> = dist_calc
let mut candidates: BinaryHeap<Reverse<OrderedNode>> = dist_calc
.distance(start)
.iter()
.zip(start)
.map(|(&dist, id)| Reverse((dist.into(), *id)))
.map(|(&dist, id)| Reverse((dist.into(), *id).into()))
.collect();
let mut results: BinaryHeap<(OrderedFloat, u32)> = candidates
let mut results: BinaryHeap<OrderedNode> = candidates
.clone()
.into_iter()
.filter(|node| {
bitset
.map(|bitset| bitset.contains(node.0 .1))
.map(|bitset| bitset.contains(node.0.id))
.unwrap_or(true)
})
.map(|v| v.0)
.collect();

while !candidates.is_empty() {
let (dist, current) = candidates.pop().expect("candidates is empty").0;
let current = candidates.pop().expect("candidates is empty").0;
let furthest = results
.peek()
.map(|kv| kv.0)
.map(|node| node.dist)
.unwrap_or(OrderedFloat(f32::INFINITY));
if dist > furthest {
if current.dist > furthest {
break;
}
let neighbors = graph.neighbors(current).ok_or_else(|| Error::Index {
message: format!("Node {} does not exist in the graph", current),
let neighbors = graph.neighbors(current.id).ok_or_else(|| Error::Index {
message: format!("Node {} does not exist in the graph", current.id),
location: location!(),
})?;

Expand All @@ -222,20 +228,20 @@ pub(super) fn beam_search(
visited.insert(neighbor);
let furthest = results
.peek()
.map(|kv| kv.0)
.map(|node| node.dist)
.unwrap_or(OrderedFloat(f32::INFINITY));
let dist = dist_calc.distance(&[neighbor])[0].into();
if dist <= furthest || results.len() < k {
if bitset
.map(|bitset| bitset.contains(neighbor))
.unwrap_or(true)
{
results.push((dist, neighbor));
results.push((dist, neighbor).into());
if results.len() > k {
results.pop();
}
}
candidates.push(Reverse((dist, neighbor)));
candidates.push(Reverse((dist, neighbor).into()));
}
}
}
Expand Down
34 changes: 19 additions & 15 deletions rust/lance-index/src/vector/graph/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use snafu::{location, Location};
use super::OrderedNode;
use super::{memory::InMemoryVectorStorage, Graph, GraphNode, OrderedFloat};
use crate::vector::graph::storage::VectorStorage;
use crate::vector::hnsw::select_neighbors_heuristic;

/// GraphNode during build.
#[derive(Debug, Clone)]
Expand All @@ -43,15 +44,6 @@ impl GraphBuilderNode {
fn add_neighbor(&mut self, distance: OrderedFloat, id: u32) {
self.neighbors.push(OrderedNode { dist: distance, id });
}

/// Prune the node and only keep `max_edges` edges.
///
/// Returns the ids of pruned neighbors.
fn prune(&mut self, max_edges: usize) {
while self.neighbors.len() > max_edges {
self.neighbors.pop();
}
}
}

impl From<&GraphBuilderNode> for GraphNode<u32> {
Expand Down Expand Up @@ -99,7 +91,7 @@ impl Graph for GraphBuilder {
}

fn storage(&self) -> Arc<dyn VectorStorage> {
self.vectors.clone() as Arc<dyn VectorStorage>
self.vectors.clone()
}
}

Expand Down Expand Up @@ -140,11 +132,23 @@ impl GraphBuilder {
}

pub fn prune(&mut self, node: u32, max_edges: usize) -> Result<()> {
let node = self.nodes.get_mut(&node).ok_or_else(|| Error::Index {
message: format!("Node {} not found", node),
location: location!(),
})?;
node.prune(max_edges);
let vector = self.vectors.vector(node);

let neighbors = &self
.nodes
.get(&node)
.ok_or_else(|| Error::Index {
message: format!("Node {} not found", node),
location: location!(),
})?
.neighbors;

let pruned_neighbors =
select_neighbors_heuristic(self, vector, neighbors, max_edges, false).collect();

self.nodes
.entry(node)
.and_modify(|node| node.neighbors = pruned_neighbors);
Ok(())
}

Expand Down
4 changes: 4 additions & 0 deletions rust/lance-index/src/vector/graph/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ impl InMemoryVectorStorage {
}

impl VectorStorage for InMemoryVectorStorage {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn len(&self) -> usize {
self.vectors.num_rows()
}
Expand Down
4 changes: 4 additions & 0 deletions rust/lance-index/src/vector/graph/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;

use lance_linalg::distance::MetricType;

pub trait DistCalculator {
Expand All @@ -26,6 +28,8 @@ pub trait DistCalculator {
///
/// TODO: should we rename this to "VectorDistance"?;
pub trait VectorStorage: Send + Sync {
fn as_any(&self) -> &dyn Any;

fn len(&self) -> usize;

/// Returns true if this graph is empty.
Expand Down
112 changes: 72 additions & 40 deletions rust/lance-index/src/vector/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! Hierarchical Navigable Small World (HNSW).
//!
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::fmt::Debug;
use std::ops::Range;
Expand Down Expand Up @@ -44,6 +45,8 @@ use snafu::{location, Location};

use self::builder::HNSW_METADATA_KEY;

use super::graph::memory::InMemoryVectorStorage;
use super::graph::OrderedNode;
use super::graph::{
builder::GraphBuilder,
greedy_search,
Expand Down Expand Up @@ -352,7 +355,7 @@ impl HNSW {
k: usize,
ef: usize,
bitset: Option<RoaringBitmap>,
) -> Result<Vec<(u32, f32)>> {
) -> Result<Vec<OrderedNode>> {
let mut ep = self.entry_point;
let num_layers = self.levels.len();

Expand All @@ -371,9 +374,7 @@ impl HNSW {
Some(dist_calc),
bitset.as_ref(),
)?;
Ok(select_neighbors(&candidates, k)
.map(|(d, u)| (u, d.into()))
.collect())
Ok(select_neighbors(&candidates, k).cloned().collect())
}

/// Returns the metadata of this [`HNSW`].
Expand Down Expand Up @@ -477,85 +478,116 @@ impl HNSW {
///
/// Algorithm 3 in the HNSW paper.
fn select_neighbors(
orderd_candidates: &BinaryHeap<(OrderedFloat, u32)>,
orderd_candidates: &BinaryHeap<OrderedNode>,
k: usize,
) -> impl Iterator<Item = (OrderedFloat, u32)> + '_ {
orderd_candidates
.iter()
.sorted()
.take(k)
.map(|(d, u)| (*d, *u))
) -> impl Iterator<Item = &OrderedNode> + '_ {
orderd_candidates.iter().sorted().take(k)
}

/// Algorithm 4 in the HNSW paper.
///
///
/// Modifies to the original algorithm:
/// 1. Do not use keepPrunedConnections, we use a heap to capture nearest neighbors.
fn select_neighbors_heuristic(
graph: &dyn Graph,
/// NOTE: the result is not ordered
pub(crate) fn select_neighbors_heuristic(
graph: &GraphBuilder,
query: &[f32],
orderd_candidates: &BinaryHeap<(OrderedFloat, u32)>,
orderd_candidates: &BinaryHeap<OrderedNode>,
k: usize,
extend_candidates: bool,
) -> impl Iterator<Item = (OrderedFloat, u32)> {
let mut heap = orderd_candidates.clone();
) -> impl Iterator<Item = OrderedNode> {
let mut w = orderd_candidates
.iter()
.cloned()
.map(Reverse)
.collect::<BinaryHeap<_>>();

if extend_candidates {
let dist_calc = graph.storage().dist_calculator(query);
let mut visited = HashSet::with_capacity(orderd_candidates.len() * 64);
visited.extend(orderd_candidates.iter().map(|(_, u)| *u));
orderd_candidates.iter().sorted().rev().for_each(|(_, u)| {
if let Some(neighbors) = graph.neighbors(*u) {
let mut visited = HashSet::with_capacity(orderd_candidates.len() * k);
visited.extend(orderd_candidates.iter().map(|node| node.id));
orderd_candidates.iter().sorted().rev().for_each(|node| {
if let Some(neighbors) = graph.neighbors(node.id) {
neighbors.for_each(|n| {
if !visited.contains(&n) {
let d: OrderedFloat = dist_calc.distance(&[n])[0].into();
heap.push((d, n));
w.push(Reverse((d, n).into()));
}
visited.insert(n);
});
}
});
}

heap.into_sorted_vec().into_iter().take(k)
let mut results: Vec<OrderedNode> = Vec::with_capacity(k);
let mut discarded = Vec::new();
let storage = graph.storage();
let storage = storage
.as_any()
.downcast_ref::<InMemoryVectorStorage>()
.unwrap();
while !w.is_empty() && results.len() < k {
let u = w.pop().unwrap().0;

if results.is_empty()
|| results
.iter()
.all(|v| u.dist < OrderedFloat(storage.distance_between(u.id, v.id)))
{
results.push(u);
} else {
discarded.push(u);
}
}

while results.len() < k && !discarded.is_empty() {
results.push(discarded.pop().unwrap());
}

results.into_iter()
}

#[cfg(test)]
mod tests {
use super::*;

use crate::vector::graph::memory::InMemoryVectorStorage;
use arrow_array::types::Float32Type;
use lance_linalg::matrix::MatrixView;
use lance_testing::datagen::generate_random_array;
use tests::builder::HnswBuildParams;

#[test]
fn test_select_neighbors() {
let candidates: BinaryHeap<(OrderedFloat, u32)> =
(1..6).map(|i| (OrderedFloat(i as f32), i)).collect();
let candidates: BinaryHeap<OrderedNode> =
(1..6).map(|i| (OrderedFloat(i as f32), i).into()).collect();

let result = select_neighbors(&candidates, 3).collect::<Vec<_>>();
let result = select_neighbors(&candidates, 3)
.cloned()
.collect::<Vec<_>>();
assert_eq!(
result,
vec![
(OrderedFloat(1.0), 1),
(OrderedFloat(2.0), 2),
(OrderedFloat(3.0), 3),
OrderedNode::new(1, OrderedFloat(1.0)),
OrderedNode::new(2, OrderedFloat(2.0)),
OrderedNode::new(3, OrderedFloat(3.0)),
]
);

assert_eq!(select_neighbors(&candidates, 0).collect::<Vec<_>>(), vec![]);
assert_eq!(
select_neighbors(&candidates, 0)
.cloned()
.collect::<Vec<_>>(),
vec![]
);

assert_eq!(
select_neighbors(&candidates, 8).collect::<Vec<_>>(),
select_neighbors(&candidates, 8)
.cloned()
.collect::<Vec<_>>(),
vec![
(OrderedFloat(1.0), 1),
(OrderedFloat(2.0), 2),
(OrderedFloat(3.0), 3),
(OrderedFloat(4.0), 4),
(OrderedFloat(5.0), 5),
OrderedNode::new(1, OrderedFloat(1.0)),
OrderedNode::new(2, OrderedFloat(2.0)),
OrderedNode::new(3, OrderedFloat(3.0)),
OrderedNode::new(4, OrderedFloat(4.0)),
OrderedNode::new(5, OrderedFloat(5.0)),
]
);
}
Expand Down Expand Up @@ -632,7 +664,7 @@ mod tests {
.search(q, K, 128, None)
.unwrap()
.iter()
.map(|(i, _)| *i)
.map(|node| node.id)
.collect();
let gt = ground_truth(&mat, q, K);
let recall = results.intersection(&gt).count() as f32 / K as f32;
Expand Down
Loading

0 comments on commit f5980e9

Please sign in to comment.