Skip to content

Commit

Permalink
[ENH] make hnsw query orchestrator use BF operator (#1927)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- This is a low-quality commit of wiring up the BF operator just to get
things working e2e. We will rewrite a lot of this marginally when the
chunk abstraction lands.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Mar 24, 2024
1 parent a7cf00d commit 5369630
Showing 1 changed file with 67 additions and 14 deletions.
81 changes: 67 additions & 14 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::distance;
use crate::distance::DistanceFunction;
use crate::errors::ChromaError;
use crate::execution::operators::brute_force_knn::{
BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput,
BruteForceKnnOperatorResult,
};
use crate::execution::operators::pull_log::PullLogsResult;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
Expand All @@ -12,7 +18,6 @@ use crate::{
use async_trait::async_trait;
use num_bigint::BigInt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;

Expand Down Expand Up @@ -118,7 +123,7 @@ impl HnswQueryOrchestrator {
let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH);
let end_timestamp = match end_timestamp {
// TODO: change protobuf definition to use u64 instead of i64
Ok(end_timestamp) => end_timestamp.as_nanos() as i64,
Ok(end_timestamp) => end_timestamp.as_secs() as i64,
Err(e) => {
// Log an error and reply + return
return;
Expand Down Expand Up @@ -173,8 +178,45 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
self.state = ExecutionState::Dedupe;

// TODO: implement the remaining state transitions and operators
// This is an example of the final state transition and result
// TODO: don't need all this cloning and data shuffling, once we land the chunk abstraction
let mut dataset = Vec::new();
match message {
Ok(logs) => {
for log in logs.logs().iter() {
// TODO: only adds have embeddings, unwrap is fine for now
dataset.push(log.embedding.clone().unwrap());
}
let bf_input = BruteForceKnnOperatorInput {
data: dataset,
query: self.query_vectors[0].clone(),
k: self.k as usize,
distance_metric: DistanceFunction::Euclidean,
};
let operator = Box::new(BruteForceKnnOperator {});
let task = wrap(operator, bf_input, ctx.sender.as_receiver());
match self.dispatcher.send(task).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
}
}
}
Err(e) => {
// Log an error
return;
}
}
}
}

#[async_trait]
impl Handler<BruteForceKnnOperatorResult> for HnswQueryOrchestrator {
async fn handle(
&mut self,
message: BruteForceKnnOperatorResult,
ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
) {
// This is an example of the final state transition and result
let result_channel = match self.result_channel.take() {
Some(tx) => tx,
None => {
Expand All @@ -184,18 +226,29 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
};

match message {
Ok(logs) => {
// TODO: remove this after debugging
println!("Received logs: {:?}", logs);
let _ = result_channel.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
distance: 0.0,
vector: Some(vec![0.0, 0.0, 0.0]),
}]]));
Ok(output) => {
let mut result = Vec::new();
let mut query_results = Vec::new();
for (index, distance) in output.indices.iter().zip(output.distances.iter()) {
let query_result = VectorQueryResult {
id: index.to_string(),
seq_id: BigInt::from(0),
distance: *distance,
vector: None,
};
query_results.push(query_result);
}
result.push(query_results);

match result_channel.send(Ok(result)) {
Ok(_) => (),
Err(e) => {
// Log an error
}
}
}
Err(e) => {
let _ = result_channel.send(Err(Box::new(e)));
Err(_) => {
// Log an error
}
}
}
Expand Down

0 comments on commit 5369630

Please sign in to comment.