diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index d8d4d14fcda..97c0f3512d4 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -1,6 +1,12 @@ use super::super::operator::{wrap, TaskMessage}; use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput}; +use crate::distance; +use crate::distance::DistanceFunction; use crate::errors::ChromaError; +use crate::execution::operators::brute_force_knn::{ + BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput, + BruteForceKnnOperatorResult, +}; use crate::execution::operators::pull_log::PullLogsResult; use crate::sysdb::sysdb::SysDb; use crate::system::System; @@ -12,7 +18,6 @@ use crate::{ use async_trait::async_trait; use num_bigint::BigInt; use std::fmt::Debug; -use std::fmt::Formatter; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; @@ -118,7 +123,7 @@ impl HnswQueryOrchestrator { let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH); let end_timestamp = match end_timestamp { // TODO: change protobuf definition to use u64 instead of i64 - Ok(end_timestamp) => end_timestamp.as_nanos() as i64, + Ok(end_timestamp) => end_timestamp.as_secs() as i64, Err(e) => { // Log an error and reply + return return; @@ -173,8 +178,45 @@ impl Handler for HnswQueryOrchestrator { self.state = ExecutionState::Dedupe; // TODO: implement the remaining state transitions and operators - // This is an example of the final state transition and result + // TODO: don't need all this cloning and data shuffling, once we land the chunk abstraction + let mut dataset = Vec::new(); + match message { + Ok(logs) => { + for log in logs.logs().iter() { + // TODO: only adds have embeddings, unwrap is fine for now + dataset.push(log.embedding.clone().unwrap()); + } + let bf_input = BruteForceKnnOperatorInput { + data: dataset, + query: self.query_vectors[0].clone(), + k: self.k as usize, + distance_metric: DistanceFunction::Euclidean, + }; + let operator = Box::new(BruteForceKnnOperator {}); + let task = wrap(operator, bf_input, ctx.sender.as_receiver()); + match self.dispatcher.send(task).await { + Ok(_) => (), + Err(e) => { + // TODO: log an error and reply to caller + } + } + } + Err(e) => { + // Log an error + return; + } + } + } +} +#[async_trait] +impl Handler for HnswQueryOrchestrator { + async fn handle( + &mut self, + message: BruteForceKnnOperatorResult, + ctx: &crate::system::ComponentContext, + ) { + // This is an example of the final state transition and result let result_channel = match self.result_channel.take() { Some(tx) => tx, None => { @@ -184,18 +226,29 @@ impl Handler for HnswQueryOrchestrator { }; match message { - Ok(logs) => { - // TODO: remove this after debugging - println!("Received logs: {:?}", logs); - let _ = result_channel.send(Ok(vec![vec![VectorQueryResult { - id: "abc".to_string(), - seq_id: BigInt::from(0), - distance: 0.0, - vector: Some(vec![0.0, 0.0, 0.0]), - }]])); + Ok(output) => { + let mut result = Vec::new(); + let mut query_results = Vec::new(); + for (index, distance) in output.indices.iter().zip(output.distances.iter()) { + let query_result = VectorQueryResult { + id: index.to_string(), + seq_id: BigInt::from(0), + distance: *distance, + vector: None, + }; + query_results.push(query_result); + } + result.push(query_results); + + match result_channel.send(Ok(result)) { + Ok(_) => (), + Err(e) => { + // Log an error + } + } } - Err(e) => { - let _ = result_channel.send(Err(Box::new(e))); + Err(_) => { + // Log an error } } }