Skip to content

Commit

Permalink
chore: improve hnsw build time (#1996)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Feb 26, 2024
1 parent 7c6afba commit 4fe5e43
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 327 deletions.
5 changes: 3 additions & 2 deletions rust/lance-index/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@ arrow-schema.workspace = true
arrow-select.workspace = true
async-recursion.workspace = true
async-trait.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion-sql.workspace = true
datafusion.workspace = true
futures.workspace = true
half.workspace = true
itertools.workspace = true
lance-arrow.workspace = true
lance-core.workspace = true
lance-file.workspace = true
lance-datafusion.workspace = true
lance-file.workspace = true
lance-io.workspace = true
lance-linalg.workspace = true
lance-table.workspace = true
Expand Down
37 changes: 34 additions & 3 deletions rust/lance-index/src/vector/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub(crate) mod builder;
pub mod memory;
pub(super) mod storage;

use storage::VectorStorage;
/// Vector storage to back a graph.
pub use storage::VectorStorage;

pub(crate) const NEIGHBORS_COL: &str = "__neighbors";

Expand Down Expand Up @@ -87,6 +88,36 @@ impl From<OrderedFloat> for f32 {
}
}

/// A node id paired with a distance.
///
/// Ordering ([`Ord`] / [`PartialOrd`]) looks only at `dist`, whereas the
/// derived equality compares both `id` and `dist`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub(crate) struct OrderedNode {
/// Id of the node.
pub id: u32,
/// Distance associated with this node; the sole key used for ordering.
pub dist: OrderedFloat,
}

impl PartialOrd for OrderedNode {
    /// Compares two nodes by delegating to [`Ord::cmp`], so the partial and
    /// total orders can never drift apart. The result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Canonical form for a type that is also `Ord`: keep the comparison
        // logic in exactly one place (`Ord::cmp`, which orders by `dist`).
        Some(self.cmp(other))
    }
}

impl Ord for OrderedNode {
    /// Total order over nodes, decided solely by `dist`; `id` is not consulted.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.dist, &other.dist)
    }
}

impl From<(OrderedFloat, u32)> for OrderedNode {
    /// Builds a node from a `(distance, id)` pair.
    fn from(pair: (OrderedFloat, u32)) -> Self {
        let (dist, id) = pair;
        Self { id, dist }
    }
}

impl From<OrderedNode> for (OrderedFloat, u32) {
fn from(node: OrderedNode) -> Self {
(node.dist, node.id)
}
}

/// Distance calculator.
///
/// This trait is used to calculate the distance from a query vector to a stream of vector IDs.
Expand All @@ -113,7 +144,7 @@ pub trait Graph {
}

/// Get the neighbors of a graph node, identified by the index.
fn neighbors(&self, key: u32) -> Option<Box<dyn Iterator<Item = &u32> + '_>>;
fn neighbors(&self, key: u32) -> Option<Box<dyn Iterator<Item = u32> + '_>>;

/// Access to the underlying storage.
fn storage(&self) -> Arc<dyn VectorStorage>;
Expand Down Expand Up @@ -163,7 +194,7 @@ pub(super) fn beam_search(
location: location!(),
})?;

for &neighbor in neighbors {
for neighbor in neighbors {
if visited.contains(&neighbor) {
continue;
}
Expand Down
75 changes: 48 additions & 27 deletions rust/lance-index/src/vector/graph/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;

use lance_core::{Error, Result};
use snafu::{location, Location};

use super::OrderedNode;
use super::{memory::InMemoryVectorStorage, Graph, GraphNode, OrderedFloat};
use crate::vector::graph::storage::VectorStorage;

Expand All @@ -26,44 +27,48 @@ use crate::vector::graph::storage::VectorStorage;
pub struct GraphBuilderNode {
/// Node ID
pub(crate) id: u32,
/// Neighbors, sorted by the distance.
pub(crate) neighbors: BTreeMap<OrderedFloat, u32>,

/// Pointer to the next level of graph, or acts as the idx
pub pointer: u32,
/// Neighbors, sorted by the distance.
pub(crate) neighbors: BinaryHeap<OrderedNode>,
}

impl GraphBuilderNode {
fn new(id: u32) -> Self {
Self {
id,
neighbors: BTreeMap::new(),
pointer: 0,
neighbors: BinaryHeap::new(),
}
}

/// Record `id` as a neighbor of this node at the given `distance`.
///
/// The pair is wrapped in an [`OrderedNode`] and pushed onto `self.neighbors`.
fn add_neighbor(&mut self, distance: f32, id: u32) {
    let dist = OrderedFloat(distance);
    self.neighbors.push(OrderedNode { id, dist });
}

/// Prune the node and only keep `max_edges` edges.
///
/// Returns the ids of pruned neighbors.
fn prune(&mut self, max_edges: usize) -> Vec<u32> {
if self.neighbors.len() <= max_edges {
return vec![];
}

let mut pruned = Vec::with_capacity(self.neighbors.len() - max_edges);
fn prune(&mut self, max_edges: usize) {
while self.neighbors.len() > max_edges {
let (_, node) = self.neighbors.pop_last().unwrap();
pruned.push(node)
self.neighbors.pop();
}
pruned
}
}

impl From<&GraphBuilderNode> for GraphNode<u32> {
fn from(node: &GraphBuilderNode) -> Self {
let neighbors = node
.neighbors
.clone()
.into_sorted_vec()
.into_iter()
.map(|n| n.id)
.collect::<Vec<_>>();
Self {
id: node.id,
neighbors: node.neighbors.values().copied().collect(),
neighbors,
}
}
}
Expand All @@ -73,7 +78,7 @@ impl From<&GraphBuilderNode> for GraphNode<u32> {
/// [GraphBuilder] is used to build a graph in memory.
///
pub struct GraphBuilder {
pub(crate) nodes: BTreeMap<u32, GraphBuilderNode>,
pub(crate) nodes: HashMap<u32, GraphBuilderNode>,

/// Storage for vectors.
vectors: Arc<InMemoryVectorStorage>,
Expand All @@ -84,9 +89,15 @@ impl Graph for GraphBuilder {
self.nodes.len()
}

fn neighbors(&self, key: u32) -> Option<Box<dyn Iterator<Item = &u32> + '_>> {
fn neighbors(&self, key: u32) -> Option<Box<dyn Iterator<Item = u32> + '_>> {
let node = self.nodes.get(&key)?;
Some(Box::new(node.neighbors.values()))
Some(Box::new(
node.neighbors
.clone()
.into_sorted_vec()
.into_iter()
.map(|n| n.id),
))
}

fn storage(&self) -> Arc<dyn VectorStorage> {
Expand All @@ -98,7 +109,7 @@ impl GraphBuilder {
/// Build from a [VectorStorage].
pub fn new(vectors: Arc<InMemoryVectorStorage>) -> Self {
Self {
nodes: BTreeMap::new(),
nodes: HashMap::new(),
vectors,
}
}
Expand All @@ -110,22 +121,22 @@ impl GraphBuilder {

/// Connect from one node to another.
pub fn connect(&mut self, from: u32, to: u32) -> Result<()> {
let distance: OrderedFloat = self.vectors.distance_between(from, to).into();
let distance = self.vectors.distance_between(from, to);

{
let from_node = self.nodes.get_mut(&from).ok_or_else(|| Error::Index {
message: format!("Node {} not found", from),
location: location!(),
})?;
from_node.neighbors.insert(distance, to);
from_node.add_neighbor(distance, to)
}

{
let to_node = self.nodes.get_mut(&to).ok_or_else(|| Error::Index {
message: format!("Node {} not found", to),
location: location!(),
})?;
to_node.neighbors.insert(distance, from);
to_node.add_neighbor(distance, from);
}
Ok(())
}
Expand All @@ -149,7 +160,7 @@ impl GraphBuilder {
let edges = node.neighbors.len();
total_edges += edges;
max_edges = max_edges.max(edges);
total_distance += node.neighbors.keys().map(|d| d.0).sum::<f32>();
total_distance += node.neighbors.iter().map(|n| n.dist.0).sum::<f32>();
}

GraphBuilderStats {
Expand Down Expand Up @@ -189,7 +200,17 @@ mod tests {
builder.connect(0, 1).unwrap();
assert_eq!(builder.len(), 2);

assert_eq!(builder.neighbors(0).unwrap().collect::<Vec<_>>(), vec![&1]);
assert_eq!(builder.neighbors(1).unwrap().collect::<Vec<_>>(), vec![&0]);
assert_eq!(builder.neighbors(0).unwrap().collect::<Vec<_>>(), vec![1]);
assert_eq!(builder.neighbors(1).unwrap().collect::<Vec<_>>(), vec![0]);

builder.insert(4);
builder.connect(0, 4).unwrap();
assert_eq!(builder.len(), 3);

assert_eq!(
builder.neighbors(0).unwrap().collect::<Vec<_>>(),
vec![1, 4]
);
assert_eq!(builder.neighbors(1).unwrap().collect::<Vec<_>>(), vec![0]);
}
}
5 changes: 5 additions & 0 deletions rust/lance-index/src/vector/graph/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ pub trait DistCalculator {
pub trait VectorStorage: Send + Sync {
fn len(&self) -> usize;

/// Returns true if this storage is empty, i.e. [`Self::len`] is 0.
fn is_empty(&self) -> bool {
self.len() == 0
}

/// Return the metric type of the vectors.
fn metric_type(&self) -> MetricType;

Expand Down
Loading

0 comments on commit 4fe5e43

Please sign in to comment.