From c95edd141d9cd72aa6bca6424f717b3635ff0780 Mon Sep 17 00:00:00 2001
From: Hammad Bashir
Date: Thu, 2 May 2024 17:06:51 -0700
Subject: [PATCH] [ENH] Query merging (#2066)

## Description of changes

*Summarize the changes made by this PR.*
- Improvements & Bug fixes
  - /
- New functionality
  - Completes the query merging flow by adding an HNSW query operator and a
    merge results operator that rehydrates user IDs from the record segment
    blockfile.

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
---
 rust/worker/Dockerfile                        |   2 +-
 rust/worker/bindings.cpp                      |   5 +
 .../src/execution/operators/hnsw_knn.rs       |  36 ++
 .../execution/operators/merge_knn_results.rs  | 171 ++++++
 rust/worker/src/execution/operators/mod.rs    |   2 +
 .../src/execution/orchestration/compact.rs    |   2 +
 .../src/execution/orchestration/hnsw.rs       | 563 +++++++++++++++---
 rust/worker/src/index/hnsw.rs                 |  17 +-
 .../src/segment/distributed_hnsw_segment.rs   | 123 ++--
 rust/worker/src/segment/record_segment.rs     | 283 +++++++--
 rust/worker/src/server.rs                     |  23 +
 rust/worker/src/storage/s3.rs                 |   1 -
 rust/worker/src/types/segment.rs              |  12 +
 13 files changed, 1044 insertions(+), 196 deletions(-)
 create mode 100644 rust/worker/src/execution/operators/hnsw_knn.rs
 create mode 100644 rust/worker/src/execution/operators/merge_knn_results.rs

diff --git a/rust/worker/Dockerfile b/rust/worker/Dockerfile
index 9ec799db8e2..a4e3e7579b6 100644
--- a/rust/worker/Dockerfile
+++ b/rust/worker/Dockerfile
@@ -36,7 +36,7 @@ FROM debian:bookworm-slim as query_service
 COPY --from=query_service_builder /chroma/query_service .
 COPY --from=query_service_builder /chroma/rust/worker/chroma_config.yaml .
 
-RUN apt-get update && apt-get install -y libssl-dev
+RUN apt-get update && apt-get install -y libssl-dev ca-certificates
 
 ENTRYPOINT [ "./query_service" ]
 
diff --git a/rust/worker/bindings.cpp b/rust/worker/bindings.cpp
index 982d14dd5d8..c296700ae6f 100644
--- a/rust/worker/bindings.cpp
+++ b/rust/worker/bindings.cpp
@@ -200,4 +200,9 @@ extern "C"
     {
         index->set_ef(ef);
     }
+
+    int len(Index<float> *index)
+    {
+        return index->appr_alg->getCurrentElementCount() - index->appr_alg->getDeletedCount();
+    }
 }
diff --git a/rust/worker/src/execution/operators/hnsw_knn.rs b/rust/worker/src/execution/operators/hnsw_knn.rs
new file mode 100644
index 00000000000..188dbf1d4d8
--- /dev/null
+++ b/rust/worker/src/execution/operators/hnsw_knn.rs
@@ -0,0 +1,36 @@
+use crate::{
+    errors::ChromaError, execution::operator::Operator,
+    segment::distributed_hnsw_segment::DistributedHNSWSegment,
+};
+use async_trait::async_trait;
+
+#[derive(Debug)]
+pub struct HnswKnnOperator {}
+
+#[derive(Debug)]
+pub struct HnswKnnOperatorInput {
+    pub segment: Box<DistributedHNSWSegment>,
+    pub query: Vec<f32>,
+    pub k: usize,
+}
+
+#[derive(Debug)]
+pub struct HnswKnnOperatorOutput {
+    pub offset_ids: Vec<usize>,
+    pub distances: Vec<f32>,
+}
+
+pub type HnswKnnOperatorResult = Result<HnswKnnOperatorOutput, Box<dyn ChromaError>>;
+
+#[async_trait]
+impl Operator<HnswKnnOperatorInput, HnswKnnOperatorOutput> for HnswKnnOperator {
+    type Error = Box<dyn ChromaError>;
+
+    async fn run(&self, input: &HnswKnnOperatorInput) -> HnswKnnOperatorResult {
+        let (offset_ids, distances) = input.segment.query(&input.query, input.k);
+        Ok(HnswKnnOperatorOutput {
+            offset_ids,
+            distances,
+        })
+    }
+}
diff --git a/rust/worker/src/execution/operators/merge_knn_results.rs b/rust/worker/src/execution/operators/merge_knn_results.rs
new file mode 100644
index 00000000000..31ce80622a4
--- /dev/null
+++ b/rust/worker/src/execution/operators/merge_knn_results.rs
@@ -0,0 +1,171 @@
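+//! Merges the top-k results from the HNSW segment (keyed by offset id) with
+//! the top-k results from the brute force operator over the pulled log
+//! (keyed by user id). Offset ids are rehydrated into user ids through the
+//! record segment blockfile before the two lists are merged.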
+use crate::{
+    blockstore::provider::BlockfileProvider,
+    errors::ChromaError,
+    execution::operator::Operator,
+    segment::record_segment::{RecordSegmentReader, RecordSegmentReaderCreationError},
+    types::Segment,
+};
+use async_trait::async_trait;
+use thiserror::Error;
+
+#[derive(Debug)]
+pub struct MergeKnnResultsOperator {}
+
+#[derive(Debug)]
+pub struct MergeKnnResultsOperatorInput {
+    hnsw_result_offset_ids: Vec<usize>,
+    hnsw_result_distances: Vec<f32>,
+    brute_force_result_user_ids: Vec<String>,
+    brute_force_result_distances: Vec<f32>,
+    k: usize,
+    record_segment_definition: Segment,
+    blockfile_provider: BlockfileProvider,
+}
+
+impl MergeKnnResultsOperatorInput {
+    pub fn new(
+        hnsw_result_offset_ids: Vec<usize>,
+        hnsw_result_distances: Vec<f32>,
+        brute_force_result_user_ids: Vec<String>,
+        brute_force_result_distances: Vec<f32>,
+        k: usize,
+        record_segment_definition: Segment,
+        blockfile_provider: BlockfileProvider,
+    ) -> Self {
+        Self {
+            hnsw_result_offset_ids,
+            hnsw_result_distances,
+            brute_force_result_user_ids,
+            brute_force_result_distances,
+            k,
+            record_segment_definition,
+            blockfile_provider,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct MergeKnnResultsOperatorOutput {
+    pub user_ids: Vec<String>,
+    pub distances: Vec<f32>,
+}
+
+#[derive(Error, Debug)]
+pub enum MergeKnnResultsOperatorError {}
+
+impl ChromaError for MergeKnnResultsOperatorError {
+    fn code(&self) -> crate::errors::ErrorCodes {
+        crate::errors::ErrorCodes::UNKNOWN
+    }
+}
+
+pub type MergeKnnResultsOperatorResult =
+    Result<MergeKnnResultsOperatorOutput, Box<dyn ChromaError>>;
+
+#[async_trait]
+impl Operator<MergeKnnResultsOperatorInput, MergeKnnResultsOperatorOutput>
+    for MergeKnnResultsOperator
+{
+    type Error = Box<dyn ChromaError>;
+
+    async fn run(&self, input: &MergeKnnResultsOperatorInput) -> MergeKnnResultsOperatorResult {
+        let (result_user_ids, result_distances) = match RecordSegmentReader::from_segment(
+            &input.record_segment_definition,
+            &input.blockfile_provider,
+        )
+        .await
+        {
+            Ok(reader) => {
+                println!("Record Segment Reader created successfully");
+                // Convert the HNSW result offset IDs to user IDs
+                let mut hnsw_result_user_ids = Vec::new();
+                for offset_id in &input.hnsw_result_offset_ids {
+                    let user_id = reader.get_user_id_for_offset_id(*offset_id as u32).await;
+                    match user_id {
+                        Ok(user_id) => hnsw_result_user_ids.push(user_id),
+                        Err(e) => return Err(e),
+                    }
+                }
+                merge_results(
+                    &hnsw_result_user_ids,
+                    &input.hnsw_result_distances,
+                    &input.brute_force_result_user_ids,
+                    &input.brute_force_result_distances,
+                    input.k,
+                )
+            }
+            Err(e) => match *e {
+                RecordSegmentReaderCreationError::BlockfileOpenError(e) => {
+                    return Err(e);
+                }
+                RecordSegmentReaderCreationError::InvalidNumberOfFiles => {
+                    return Err(e);
+                }
+                RecordSegmentReaderCreationError::UninitializedSegment => {
+                    // The record segment doesn't exist yet, which implies there are
+                    // no compacted HNSW results either - merge only the brute force results
+                    let hnsw_result_user_ids = Vec::new();
+                    let hnsw_result_distances = Vec::new();
+                    merge_results(
+                        &hnsw_result_user_ids,
+                        &hnsw_result_distances,
+                        &input.brute_force_result_user_ids,
+                        &input.brute_force_result_distances,
+                        input.k,
+                    )
+                }
+            },
+        };
+
+        Ok(MergeKnnResultsOperatorOutput {
+            user_ids: result_user_ids,
+            distances: result_distances,
+        })
+    }
+}
+
+fn merge_results(
+    hnsw_result_user_ids: &Vec<&str>,
+    hnsw_result_distances: &Vec<f32>,
+    brute_force_result_user_ids: &Vec<String>,
+    brute_force_result_distances: &Vec<f32>,
+    k: usize,
+) -> (Vec<String>, Vec<f32>) {
+    let mut result_user_ids = Vec::with_capacity(k);
+    let mut result_distances = Vec::with_capacity(k);
+
+    // Merge the HNSW and brute force results together by the minimum distance top k
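+    // Both lists are assumed to arrive sorted by ascending distance, so a
+    // two-pointer merge suffices: e.g. with k = 3, hnsw distances [0.1, 0.4]
+    // and brute force distances [0.2, 0.3] merge to [0.1, 0.2, 0.3].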
+    let mut hnsw_index = 0;
+    let mut brute_force_index = 0;
+
+    // TODO: This doesn't have to clone the user IDs, but it's easier for now
+    while (result_user_ids.len() < k)
+        && (hnsw_index < hnsw_result_user_ids.len()
+            || brute_force_index < brute_force_result_user_ids.len())
+    {
+        if hnsw_index < hnsw_result_user_ids.len()
+            && brute_force_index < brute_force_result_user_ids.len()
+        {
+            if hnsw_result_distances[hnsw_index] < brute_force_result_distances[brute_force_index] {
+                result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string());
+                result_distances.push(hnsw_result_distances[hnsw_index]);
+                hnsw_index += 1;
+            } else {
+                result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string());
+                result_distances.push(brute_force_result_distances[brute_force_index]);
+                brute_force_index += 1;
+            }
+        } else if hnsw_index < hnsw_result_user_ids.len() {
+            result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string());
+            result_distances.push(hnsw_result_distances[hnsw_index]);
+            hnsw_index += 1;
+        } else if brute_force_index < brute_force_result_user_ids.len() {
+            result_user_ids.push(brute_force_result_user_ids[brute_force_index].to_string());
+            result_distances.push(brute_force_result_distances[brute_force_index]);
+            brute_force_index += 1;
+        }
+    }
+
+    (result_user_ids, result_distances)
+}
diff --git a/rust/worker/src/execution/operators/mod.rs b/rust/worker/src/execution/operators/mod.rs
index 6bd208aab66..871cb2d7b7d 100644
--- a/rust/worker/src/execution/operators/mod.rs
+++ b/rust/worker/src/execution/operators/mod.rs
@@ -1,5 +1,7 @@
 pub(super) mod brute_force_knn;
 pub(super) mod flush_s3;
+pub(super) mod hnsw_knn;
+pub(super) mod merge_knn_results;
 pub(super) mod normalize_vectors;
 pub(super) mod partition;
 pub(super) mod pull_log;
diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs
index c7021991bf4..da2cd3d497a 100644
--- a/rust/worker/src/execution/orchestration/compact.rs
+++ b/rust/worker/src/execution/orchestration/compact.rs
@@ -178,6 +178,8 @@ impl CompactOrchestrator {
         };
         let input = PullLogsInput::new(
             collection_id,
+            // Here we do not need to be inclusive since the compaction job
+            // offset is the one after the last compaction offset
            self.compaction_job.offset,
             100,
             None,
diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs
index fb17da05a55..b3ad7e8ed8b 100644
--- a/rust/worker/src/execution/orchestration/hnsw.rs
+++ b/rust/worker/src/execution/orchestration/hnsw.rs
@@ -1,14 +1,24 @@
 use super::super::operator::{wrap, TaskMessage};
 use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator};
+use crate::blockstore::provider::BlockfileProvider;
 use crate::distance::DistanceFunction;
-use crate::errors::ChromaError;
+use crate::errors::{ChromaError, ErrorCodes};
+use crate::execution::data::data_chunk::Chunk;
 use crate::execution::operators::brute_force_knn::{
     BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorResult,
 };
+use crate::execution::operators::hnsw_knn::{
+    HnswKnnOperator, HnswKnnOperatorInput, HnswKnnOperatorResult,
+};
+use crate::execution::operators::merge_knn_results::{
+    MergeKnnResultsOperator, MergeKnnResultsOperatorInput, MergeKnnResultsOperatorResult,
+};
 use crate::execution::operators::pull_log::PullLogsResult;
-use crate::sysdb::sysdb::SysDb;
-use crate::system::System;
-use crate::types::VectorQueryResult;
+use crate::index::hnsw_provider::HnswIndexProvider;
+use
crate::segment::distributed_hnsw_segment::DistributedHNSWSegment; +use crate::sysdb::sysdb::{GetCollectionsError, GetSegmentsError, SysDb}; +use crate::system::{ComponentContext, System}; +use crate::types::{Collection, LogRecord, Segment, SegmentType, VectorQueryResult}; use crate::{ log::log::Log, system::{Component, Handler, Receiver}, @@ -16,6 +26,7 @@ use crate::{ use async_trait::async_trait; use std::fmt::Debug; use std::time::{SystemTime, UNIX_EPOCH}; +use thiserror::Error; use tracing::{trace, trace_span, Instrument, Span}; use uuid::Uuid; @@ -28,7 +39,7 @@ understand. We can always add more abstraction later if we need it. ┌───► Brute Force ─────┐ │ │ - Pending ─► PullLogs ─► Group│ ├─► MergeResults ─► Finished + Pending ─► PullLogs ─► Group │ ├─► MergeResults ─► Finished │ │ └───► HNSW ────────────┘ @@ -39,11 +50,43 @@ enum ExecutionState { Pending, PullLogs, Partition, - QueryKnn, + QueryKnn, // This is both the Brute force and HNSW query state MergeResults, Finished, } +#[derive(Error, Debug)] +enum HnswSegmentQueryError { + #[error("Hnsw segment with id: {0} not found")] + HnswSegmentNotFound(Uuid), + #[error("Get segments error")] + GetSegmentsError(#[from] GetSegmentsError), + #[error("Collection: {0} not found")] + CollectionNotFound(Uuid), + #[error("Get collection error")] + GetCollectionError(#[from] GetCollectionsError), + #[error("Record segment not found for collection: {0}")] + RecordSegmentNotFound(Uuid), + #[error("HNSW segment has no collection")] + HnswSegmentHasNoCollection, + #[error("Collection has no dimension set")] + CollectionHasNoDimension, +} + +impl ChromaError for HnswSegmentQueryError { + fn code(&self) -> ErrorCodes { + match self { + HnswSegmentQueryError::HnswSegmentNotFound(_) => ErrorCodes::NotFound, + HnswSegmentQueryError::GetSegmentsError(_) => ErrorCodes::Internal, + HnswSegmentQueryError::CollectionNotFound(_) => ErrorCodes::NotFound, + HnswSegmentQueryError::GetCollectionError(_) => ErrorCodes::Internal, + HnswSegmentQueryError::RecordSegmentNotFound(_) => ErrorCodes::NotFound, + HnswSegmentQueryError::HnswSegmentHasNoCollection => ErrorCodes::InvalidArgument, + HnswSegmentQueryError::CollectionHasNoDimension => ErrorCodes::InvalidArgument, + } + } +} + #[derive(Debug)] pub(crate) struct HnswQueryOrchestrator { state: ExecutionState, @@ -53,11 +96,23 @@ pub(crate) struct HnswQueryOrchestrator { query_vectors: Vec>, k: i32, include_embeddings: bool, - segment_id: Uuid, + hnsw_segment_id: Uuid, + // State fetched or created for query execution + hnsw_segment: Option, + record_segment: Option, + collection: Option, + hnsw_result_offset_ids: Option>, + hnsw_result_distances: Option>, + brute_force_result_user_ids: Option>, + brute_force_result_distances: Option>, + // State machine management + merge_dependency_count: u32, // Services log: Box, sysdb: Box, dispatcher: Box>, + hnsw_index_provider: HnswIndexProvider, + blockfile_provider: BlockfileProvider, // Result channel result_channel: Option< tokio::sync::oneshot::Sender>, Box>>, @@ -73,60 +128,37 @@ impl HnswQueryOrchestrator { segment_id: Uuid, log: Box, sysdb: Box, + hnsw_index_provider: HnswIndexProvider, + blockfile_provider: BlockfileProvider, dispatcher: Box>, ) -> Self { HnswQueryOrchestrator { state: ExecutionState::Pending, system, + merge_dependency_count: 2, query_vectors, k, include_embeddings, - segment_id, + hnsw_segment_id: segment_id, + hnsw_segment: None, + record_segment: None, + collection: None, + hnsw_result_offset_ids: None, + hnsw_result_distances: None, + 
brute_force_result_user_ids: None, + brute_force_result_distances: None, log, sysdb, dispatcher, + hnsw_index_provider, + blockfile_provider, result_channel: None, } } - /// Get the collection id for a segment id. - /// TODO: This can be cached - async fn get_collection_id_for_segment_id(&mut self, segment_id: Uuid) -> Option { - let segments = self - .sysdb - .get_segments(Some(segment_id), None, None, None) - .await; - match segments { - Ok(segments) => match segments.get(0) { - Some(segment) => { - trace!("Collection Id {:?}", segment.collection); - segment.collection - } - None => None, - }, - Err(e) => { - // Log an error and return - return None; - } - } - } - async fn pull_logs(&mut self, self_address: Box>) { self.state = ExecutionState::PullLogs; let operator = PullLogsOperator::new(self.log.clone()); - let child_span: tracing::Span = - trace_span!(parent: Span::current(), "get collection id for segment id"); - let get_collection_id_future = self.get_collection_id_for_segment_id(self.segment_id); - let collection_id = match get_collection_id_future - .instrument(child_span.clone()) - .await - { - Some(collection_id) => collection_id, - None => { - // Log an error and reply + return - return; - } - }; let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH); let end_timestamp = match end_timestamp { // TODO: change protobuf definition to use u64 instead of i64 @@ -136,11 +168,24 @@ impl HnswQueryOrchestrator { return; } }; - let input = PullLogsInput::new(collection_id, 0, 100, None, Some(end_timestamp)); + + let collection = self + .collection + .as_ref() + .expect("State machine invariant violation. The collection is not set when pulling logs. This should never happen."); + + let input = PullLogsInput::new( + collection.id, + // The collection log position is inclusive, and we want to start from the next log + collection.log_position + 1, + 100, + None, + Some(end_timestamp), + ); let task = wrap(operator, input, self_address); // Wrap the task with current span as the parent. The worker then executes it // inside a child span with this parent. - match self.dispatcher.send(task, Some(child_span.clone())).await { + match self.dispatcher.send(task, Some(Span::current())).await { Ok(_) => (), Err(e) => { // TODO: log an error and reply to caller @@ -148,6 +193,237 @@ impl HnswQueryOrchestrator { } } + async fn brute_force_query( + &mut self, + logs: Chunk, + self_address: Box>, + ) { + self.state = ExecutionState::QueryKnn; + + // TODO: We shouldn't have to clone query vectors here. We should be able to pass a Arc<[f32]>-like to the input + let bf_input = BruteForceKnnOperatorInput { + data: logs, + query: self.query_vectors[0].clone(), + k: self.k as usize, + // TODO: get the distance metric from the segment metadata + distance_metric: DistanceFunction::Euclidean, + }; + let operator = Box::new(BruteForceKnnOperator {}); + let task = wrap(operator, bf_input, self_address); + match self.dispatcher.send(task, Some(Span::current())).await { + Ok(_) => (), + Err(e) => { + // TODO: log an error and reply to caller + } + } + } + + async fn hnsw_segment_query(&mut self, ctx: &ComponentContext) { + self.state = ExecutionState::QueryKnn; + + let hnsw_segment = self + .hnsw_segment + .as_ref() + .expect("Invariant violation. HNSW Segment is not set"); + let dimensionality = self + .collection + .as_ref() + .expect("Invariant violation. Collection is not set") + .dimension + .expect("Invariant violation. 
Collection dimension is not set"); + + // Fetch the data needed for the duration of the query - The HNSW Segment, The record Segment and the Collection + let hnsw_segment_reader = match DistributedHNSWSegment::from_segment( + // These unwraps are safe because we have already checked that the segments are set in the orchestrator on_start + hnsw_segment, + dimensionality as usize, + self.hnsw_index_provider.clone(), + ) + .await + { + Ok(reader) => reader, + Err(e) => { + self.terminate_with_error(e, ctx); + return; + } + }; + + println!("Created HNSW Segment Reader: {:?}", hnsw_segment_reader); + + // Dispatch a query task + let operator = Box::new(HnswKnnOperator {}); + let input = HnswKnnOperatorInput { + segment: hnsw_segment_reader, + query: self.query_vectors[0].clone(), + k: self.k as usize, + }; + let task = wrap(operator, input, ctx.sender.as_receiver()); + match self.dispatcher.send(task, Some(Span::current())).await { + Ok(_) => (), + Err(e) => { + // Log an error + println!("Error sending HNSW KNN task: {:?}", e); + } + } + } + + async fn merge_results(&mut self, ctx: &ComponentContext) { + self.state = ExecutionState::MergeResults; + + let record_segment = self + .record_segment + .as_ref() + .expect("Invariant violation. Record Segment is not set"); + + let operator = Box::new(MergeKnnResultsOperator {}); + let input = MergeKnnResultsOperatorInput::new( + self.hnsw_result_offset_ids + .as_ref() + .expect("Invariant violation. HNSW result offset ids are not set") + .clone(), + self.hnsw_result_distances + .as_ref() + .expect("Invariant violation. HNSW result distances are not set") + .clone(), + self.brute_force_result_user_ids + .as_ref() + .expect("Invariant violation. Brute force result user ids are not set") + .clone(), + self.brute_force_result_distances + .as_ref() + .expect("Invariant violation. 
Brute force result distances are not set") + .clone(), + self.k as usize, + record_segment.clone(), + self.blockfile_provider.clone(), + ); + + let task = wrap(operator, input, ctx.sender.as_receiver()); + match self.dispatcher.send(task, Some(Span::current())).await { + Ok(_) => (), + Err(e) => { + // Log an error + println!("Error sending Merge KNN task: {:?}", e); + } + } + } + + async fn get_hnsw_segment_from_id( + &self, + mut sysdb: Box, + hnsw_segment_id: &Uuid, + ) -> Result> { + let segments = sysdb + .get_segments(Some(*hnsw_segment_id), None, None, None) + .await; + let segment = match segments { + Ok(segments) => { + if segments.is_empty() { + return Err(Box::new(HnswSegmentQueryError::HnswSegmentNotFound( + *hnsw_segment_id, + ))); + } + segments[0].clone() + } + Err(e) => { + return Err(Box::new(HnswSegmentQueryError::GetSegmentsError(e))); + } + }; + + if segment.r#type != SegmentType::HnswDistributed { + return Err(Box::new(HnswSegmentQueryError::HnswSegmentNotFound( + *hnsw_segment_id, + ))); + } + Ok(segment) + } + + async fn get_collection( + &self, + mut sysdb: Box, + collection_id: &Uuid, + ) -> Result> { + let child_span: tracing::Span = + trace_span!(parent: Span::current(), "get collection id for segment id"); + let collections = sysdb + .get_collections(Some(*collection_id), None, None, None) + .instrument(child_span.clone()) + .await; + match collections { + Ok(mut collections) => { + if collections.is_empty() { + return Err(Box::new(HnswSegmentQueryError::CollectionNotFound( + *collection_id, + ))); + } + Ok(collections.drain(..).next().unwrap()) + } + Err(e) => { + return Err(Box::new(HnswSegmentQueryError::GetCollectionError(e))); + } + } + } + + async fn get_record_segment_for_collection( + &self, + mut sysdb: Box, + collection_id: &Uuid, + ) -> Result> { + let segments = sysdb + .get_segments( + None, + Some(SegmentType::Record.into()), + None, + Some(*collection_id), + ) + .await; + + let segment = match segments { + Ok(mut segments) => { + if segments.is_empty() { + println!( + "1. Record segment not found for collection: {:?}", + collection_id + ); + return Err(Box::new(HnswSegmentQueryError::RecordSegmentNotFound( + *collection_id, + ))); + } + segments.drain(..).next().unwrap() + } + Err(e) => { + return Err(Box::new(HnswSegmentQueryError::GetSegmentsError(e))); + } + }; + + if segment.r#type != SegmentType::Record { + println!( + "2. Record segment not found for collection: {:?}", + collection_id + ); + return Err(Box::new(HnswSegmentQueryError::RecordSegmentNotFound( + *collection_id, + ))); + } + Ok(segment) + } + + fn terminate_with_error(&mut self, error: Box, ctx: &ComponentContext) { + let result_channel = self + .result_channel + .take() + .expect("Invariant violation. Result channel is not set."); + match result_channel.send(Err(error)) { + Ok(_) => (), + Err(e) => { + // Log an error - this implied the listener was dropped + println!("[HnswQueryOrchestrator] Result channel dropped before sending error"); + } + } + // Cancel the orchestrator so it stops processing + ctx.cancellation_token.cancel(); + } + /// Run the orchestrator and return the result. /// # Note /// Use this over spawning the component directly. 
This method will start the component and @@ -171,6 +447,62 @@ impl Component for HnswQueryOrchestrator { } async fn on_start(&mut self, ctx: &crate::system::ComponentContext) -> () { + // Populate the orchestrator with the initial state - The HNSW Segment, The Record Segment and the Collection + let hnsw_segment = match self + .get_hnsw_segment_from_id(self.sysdb.clone(), &self.hnsw_segment_id) + .await + { + Ok(segment) => segment, + Err(e) => { + self.terminate_with_error(e, ctx); + return; + } + }; + + let collection_id = match &hnsw_segment.collection { + Some(collection_id) => collection_id, + None => { + self.terminate_with_error( + Box::new(HnswSegmentQueryError::HnswSegmentHasNoCollection), + ctx, + ); + return; + } + }; + + let collection = match self.get_collection(self.sysdb.clone(), collection_id).await { + Ok(collection) => collection, + Err(e) => { + self.terminate_with_error(e, ctx); + return; + } + }; + + // Validate that the collection has a dimension set. Downstream steps will rely on this + // so that they can unwrap the dimension without checking for None + if collection.dimension.is_none() { + self.terminate_with_error( + Box::new(HnswSegmentQueryError::CollectionHasNoDimension), + ctx, + ); + return; + }; + + let record_segment = match self + .get_record_segment_for_collection(self.sysdb.clone(), collection_id) + .await + { + Ok(segment) => segment, + Err(e) => { + self.terminate_with_error(e, ctx); + return; + } + }; + + self.record_segment = Some(record_segment); + self.hnsw_segment = Some(hnsw_segment); + self.collection = Some(collection); + self.pull_logs(ctx.sender.as_receiver()).await; } } @@ -186,32 +518,14 @@ impl Handler for HnswQueryOrchestrator { ) { self.state = ExecutionState::Partition; - // TODO: implement the remaining state transitions and operators - // TODO: don't need all this cloning and data shuffling, once we land the chunk abstraction match message { - Ok(logs) => { - let bf_input = BruteForceKnnOperatorInput { - data: logs.logs(), - query: self.query_vectors[0].clone(), - k: self.k as usize, - distance_metric: DistanceFunction::Euclidean, - }; - let operator = Box::new(BruteForceKnnOperator {}); - let task = wrap(operator, bf_input, ctx.sender.as_receiver()); - match self - .dispatcher - .send(task, Some(Span::current().clone())) - .await - { - Ok(_) => (), - Err(e) => { - // TODO: log an error and reply to caller - } - } + Ok(pull_logs_output) => { + self.brute_force_query(pull_logs_output.logs(), ctx.sender.as_receiver()) + .await; + self.hnsw_segment_query(ctx).await; } Err(e) => { - // Log an error - return; + self.terminate_with_error(Box::new(e), ctx); } } } @@ -222,40 +536,101 @@ impl Handler for HnswQueryOrchestrator { async fn handle( &mut self, message: BruteForceKnnOperatorResult, - _ctx: &crate::system::ComponentContext, + ctx: &crate::system::ComponentContext, + ) { + match message { + Ok(output) => { + let mut user_ids = Vec::new(); + for index in output.indices { + let record = match output.data.get(index) { + Some(record) => record, + None => { + // return an error + return; + } + }; + user_ids.push(record.record.id.clone()); + } + self.brute_force_result_user_ids = Some(user_ids); + self.brute_force_result_distances = Some(output.distances); + } + Err(e) => { + // TODO: handle this error, technically never happens + } + } + + self.merge_dependency_count -= 1; + + if self.merge_dependency_count == 0 { + // Trigger merge results + self.merge_results(ctx).await; + } + } +} + +#[async_trait] +impl Handler for 
HnswQueryOrchestrator {
+    async fn handle(
+        &mut self,
+        message: HnswKnnOperatorResult,
+        ctx: &ComponentContext<HnswQueryOrchestrator>,
+    ) {
+        self.merge_dependency_count -= 1;
+
+        match message {
+            Ok(output) => {
+                self.hnsw_result_offset_ids = Some(output.offset_ids);
+                self.hnsw_result_distances = Some(output.distances);
+            }
+            Err(e) => {
+                self.terminate_with_error(e, ctx);
+                return;
+            }
+        }
+
+        if self.merge_dependency_count == 0 {
+            // Trigger merge results
+            self.merge_results(ctx).await;
+        }
+    }
+}
+
+#[async_trait]
+impl Handler<MergeKnnResultsOperatorResult> for HnswQueryOrchestrator {
+    async fn handle(
+        &mut self,
+        message: MergeKnnResultsOperatorResult,
+        ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
     ) {
-        // This is an example of the final state transition and result
+        self.state = ExecutionState::Finished;
+
+        let (mut output_ids, mut output_distances) = match message {
+            Ok(output) => (output.user_ids, output.distances),
+            Err(e) => {
+                self.terminate_with_error(e, ctx);
+                return;
+            }
+        };
+
+        let mut result = Vec::new();
+        let mut query_results = Vec::new();
+        for (id, distance) in output_ids.drain(..).zip(output_distances.drain(..)) {
+            let query_result = VectorQueryResult {
+                id,
+                distance,
+                vector: None,
+            };
+            query_results.push(query_result);
+        }
+        result.push(query_results);
+        trace!("Merged results: {:?}", result);
+
         let result_channel = match self.result_channel.take() {
             Some(tx) => tx,
             None => {
-                // Log an error
+                // Log an error - this is an invariant violation, the result channel should always be set
                 return;
             }
         };
 
-        match message {
-            Ok(output) => {
-                let mut result = Vec::new();
-                let mut query_results = Vec::new();
-                for (index, distance) in output.indices.iter().zip(output.distances.iter()) {
-                    let query_result = VectorQueryResult {
-                        id: index.to_string(),
-                        distance: *distance,
-                        vector: None,
-                    };
-                    query_results.push(query_result);
-                }
-                result.push(query_results);
-                trace!("Merged results: {:?}", result);
-
-                match result_channel.send(Ok(result)) {
-                    Ok(_) => (),
-                    Err(e) => {
-                        // Log an error
-                    }
-                }
-            }
-            Err(_) => {
+        match result_channel.send(Ok(result)) {
+            Ok(_) => (),
+            Err(e) => {
                 // Log an error
             }
         }
     }
diff --git a/rust/worker/src/index/hnsw.rs b/rust/worker/src/index/hnsw.rs
index 17e0bd5801e..3876d6d9375 100644
--- a/rust/worker/src/index/hnsw.rs
+++ b/rust/worker/src/index/hnsw.rs
@@ -203,8 +203,9 @@ impl Index<HnswIndexConfig> for HnswIndex {
     }
 
     fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
-        let mut ids = vec![0usize; k];
-        let mut distance = vec![0.0f32; k];
+        let actual_k = std::cmp::min(k, self.len());
+        let mut ids = vec![0usize; actual_k];
+        let mut distance = vec![0.0f32; actual_k];
         unsafe {
             knn_query(
                 self.ffi_ptr,
@@ -271,6 +272,10 @@ impl HnswIndex {
     pub fn get_ef(&self) -> usize {
         unsafe { get_ef(self.ffi_ptr) as usize }
     }
+
+    pub fn len(&self) -> usize {
+        unsafe { len(self.ffi_ptr) as usize }
+    }
 }
 
 #[link(name = "bindings", kind = "static")]
@@ -309,6 +314,7 @@ extern "C" {
 
     fn get_ef(index: *const IndexPtrFFI) -> c_int;
     fn set_ef(index: *const IndexPtrFFI, ef: c_int);
+    fn len(index: *const IndexPtrFFI) -> c_int;
 
 }
 
@@ -357,7 +363,7 @@ pub mod test {
 
     #[test]
     fn it_can_add_parallel() {
-        let n = 10;
+        let n: usize = 100;
         let d: usize = 960;
         let distance_function = DistanceFunction::InnerProduct;
         let tmp_dir = tempdir().unwrap();
@@ -406,6 +412,8 @@ pub mod test {
             index.add(ids[i], data);
         });
 
+        assert_eq!(index.len(), n);
+
         // Get the data and check it
         let mut i = 0;
         for id in ids {
@@ -461,6 +469,9 @@ pub mod test {
             index.add(ids[i], data);
         });
+
+        // Assert length
+        assert_eq!(index.len(),
n); + // Get the data and check it let mut i = 0; for id in ids { diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs index 0a5dbf9bf5a..b2ebca84cc0 100644 --- a/rust/worker/src/segment/distributed_hnsw_segment.rs +++ b/rust/worker/src/segment/distributed_hnsw_segment.rs @@ -1,5 +1,5 @@ use super::{SegmentFlusher, SegmentWriter}; -use crate::errors::ChromaError; +use crate::errors::{ChromaError, ErrorCodes}; use crate::index::hnsw_provider::HnswIndexProvider; use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig}; use crate::types::{LogRecord, Operation, Segment}; @@ -7,8 +7,8 @@ use async_trait::async_trait; use parking_lot::RwLock; use std::collections::HashMap; use std::fmt::Debug; -use std::hash::Hash; use std::sync::Arc; +use thiserror::Error; use uuid::Uuid; const HNSW_INDEX: &str = "hnsw_index"; @@ -26,6 +26,23 @@ impl Debug for DistributedHNSWSegment { } } +#[derive(Error, Debug)] +pub enum DistributedHNSWSegmentFromSegmentError { + #[error("No hnsw file found for segment")] + NoHnswFileFound, + #[error("Hnsw file id not a valid uuid")] + InvalidUUID, +} + +impl ChromaError for DistributedHNSWSegmentFromSegmentError { + fn code(&self) -> crate::errors::ErrorCodes { + match self { + DistributedHNSWSegmentFromSegmentError::NoHnswFileFound => ErrorCodes::NotFound, + DistributedHNSWSegmentFromSegmentError::InvalidUUID => ErrorCodes::InvalidArgument, + } + } +} + impl DistributedHNSWSegment { pub(crate) fn new( index: Arc>, @@ -52,22 +69,49 @@ impl DistributedHNSWSegment { // ideally, an explicit state would be better. When we implement distributed HNSW segments, // we can introduce a state in the segment metadata for this if segment.file_path.len() > 0 { - // Load the index from the files - // TODO: we should not unwrap here - let index_id = &segment.file_path.get(HNSW_INDEX).unwrap()[0]; - let index_uuid = Uuid::parse_str(index_id.as_str()).unwrap(); - let index = hnsw_index_provider - .load(&index_uuid, segment, dimensionality as i32) - .await; - match index { - Ok(index) => Ok(Box::new(DistributedHNSWSegment::new( - index, - hnsw_index_provider, - segment.id, - )?)), - Err(e) => Err(e), - } + println!("Loading HNSW index from files"); + // Check if its in the providers cache, if not load the index from the files + let index_id = match &segment.file_path.get(HNSW_INDEX) { + None => { + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::NoHnswFileFound, + )) + } + Some(files) => { + if files.is_empty() { + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::NoHnswFileFound, + )); + } else { + &files[0] + } + } + }; + + let index_uuid = match Uuid::parse_str(index_id.as_str()) { + Ok(uuid) => uuid, + Err(_) => { + return Err(Box::new( + DistributedHNSWSegmentFromSegmentError::InvalidUUID, + )) + } + }; + + let index = match hnsw_index_provider.get(&index_uuid) { + Some(index) => index, + None => { + hnsw_index_provider + .load(&index_uuid, segment, dimensionality as i32) + .await? 
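+                    // Cache miss - the index was not in the provider's cache,
+                    // so it has been loaded from the segment's files in storage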
+ } + }; + Ok(Box::new(DistributedHNSWSegment::new( + index, + hnsw_index_provider, + segment.id, + )?)) } else { + println!("Creating new HNSW index"); let index = hnsw_index_provider.create(segment, dimensionality as i32)?; Ok(Box::new(DistributedHNSWSegment::new( index, @@ -77,47 +121,10 @@ impl DistributedHNSWSegment { } } - // pub(crate) fn get_records(&self, ids: Vec) -> Vec> { - // let mut records = Vec::new(); - // let user_id_to_id = self.user_id_to_id.read(); - // let index = self.index.read(); - // for id in ids { - // let internal_id = match user_id_to_id.get(&id) { - // Some(internal_id) => internal_id, - // None => { - // // TODO: Error - // return records; - // } - // }; - // let vector = index.get(*internal_id); - // match vector { - // Some(vector) => { - // let record = VectorEmbeddingRecord { id: id, vector }; - // records.push(Box::new(record)); - // } - // None => { - // // TODO: error - // } - // } - // } - // return records; - // } - - // pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec, Vec) { - // let index = self.index.read(); - // let mut return_user_ids = Vec::new(); - // let (ids, distances) = index.query(vector, k); - // let user_ids = self.id_to_user_id.read(); - // for id in ids { - // match user_ids.get(&id) { - // Some(user_id) => return_user_ids.push(user_id.clone()), - // None => { - // // TODO: error - // } - // }; - // } - // return (return_user_ids, distances); - // } + pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec, Vec) { + let index = self.index.read(); + index.query(vector, k) + } } impl SegmentWriter for DistributedHNSWSegment { diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index eccae0d4b52..587233bcbc3 100644 --- a/rust/worker/src/segment/record_segment.rs +++ b/rust/worker/src/segment/record_segment.rs @@ -1,8 +1,8 @@ use super::types::{LogMaterializer, MaterializedLogRecord, SegmentWriter}; use super::{DataRecord, SegmentFlusher}; -use crate::blockstore::provider::{BlockfileProvider, CreateError}; +use crate::blockstore::provider::{BlockfileProvider, CreateError, OpenError}; use crate::blockstore::{BlockfileFlusher, BlockfileReader, BlockfileWriter}; -use crate::errors::ChromaError; +use crate::errors::{ChromaError, ErrorCodes}; use crate::execution::data::data_chunk::Chunk; use crate::types::{ update_metdata_to_metdata, LogRecord, Metadata, Operation, Segment, SegmentType, @@ -18,6 +18,7 @@ use uuid::Uuid; const USER_ID_TO_OFFSET_ID: &str = "user_id_to_offset_id"; const OFFSET_ID_TO_USER_ID: &str = "offset_id_to_user_id"; const OFFSET_ID_TO_DATA: &str = "offset_id_to_data"; +const MAX_OFFSET_ID: &str = "max_offset_id"; #[derive(Clone)] pub(crate) struct RecordSegmentWriter { @@ -25,7 +26,9 @@ pub(crate) struct RecordSegmentWriter { user_id_to_id: Option, id_to_user_id: Option, id_to_data: Option, - // TODO: store current max offset id in the metadata of the id_to_data blockfile + // TODO: for now we store the max offset ID in a separate blockfile, this is not ideal + // we should store it in metadata of one of the blockfiles + max_offset_id: Option, curr_max_offset_id: Arc, pub(crate) id: Uuid, // If there is an old version of the data, we need to keep it around to be able to @@ -40,7 +43,7 @@ impl Debug for RecordSegmentWriter { } #[derive(Error, Debug)] -pub enum RecordSegmentCreationError { +pub enum RecordSegmentWriterCreationError { #[error("Invalid segment type")] InvalidSegmentType, #[error("Missing file: {0}")] @@ -51,49 +54,70 @@ pub enum 
RecordSegmentCreationError { InvalidUuid(String), #[error("Blockfile Creation Error")] BlockfileCreateError(#[from] Box), + #[error("Blockfile Open Error")] + BlockfileOpenError(#[from] Box), + #[error("No exisiting offset id found")] + NoExistingOffsetId, } -impl<'a> RecordSegmentWriter { +impl RecordSegmentWriter { pub(crate) async fn from_segment( segment: &Segment, blockfile_provider: &BlockfileProvider, - ) -> Result { + ) -> Result { println!("Creating RecordSegmentWriter from Segment"); if segment.r#type != SegmentType::Record { - return Err(RecordSegmentCreationError::InvalidSegmentType); + return Err(RecordSegmentWriterCreationError::InvalidSegmentType); } - // RESUME POINT = HANDLE EXISTING FILES AND ALSO PORT HSNW TO FORK() NOT LOAD() - let (user_id_to_id, id_to_user_id, id_to_data) = match segment.file_path.len() { + let mut exising_max_offset_id = 0; + + let (user_id_to_id, id_to_user_id, id_to_data, max_offset_id) = match segment + .file_path + .len() + { 0 => { println!("No files found, creating new blockfiles for record segment"); let user_id_to_id = match blockfile_provider.create::<&str, u32>() { Ok(user_id_to_id) => user_id_to_id, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; let id_to_user_id = match blockfile_provider.create::() { Ok(id_to_user_id) => id_to_user_id, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; let id_to_data = match blockfile_provider.create::() { Ok(id_to_data) => id_to_data, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; - (user_id_to_id, id_to_user_id, id_to_data) + let max_offset_id = match blockfile_provider.create::<&str, u32>() { + Ok(max_offset_id) => max_offset_id, + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } + }; + + (user_id_to_id, id_to_user_id, id_to_data, max_offset_id) } - 3 => { + 4 => { println!("Found files, loading blockfiles for record segment"); let user_id_to_id_bf_id = match segment.file_path.get(USER_ID_TO_OFFSET_ID) { Some(user_id_to_id_bf_id) => match user_id_to_id_bf_id.get(0) { Some(user_id_to_id_bf_id) => user_id_to_id_bf_id, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( USER_ID_TO_OFFSET_ID.to_string(), )) } }, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( USER_ID_TO_OFFSET_ID.to_string(), )) } @@ -102,13 +126,13 @@ impl<'a> RecordSegmentWriter { Some(id_to_user_id_bf_id) => match id_to_user_id_bf_id.get(0) { Some(id_to_user_id_bf_id) => id_to_user_id_bf_id, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( OFFSET_ID_TO_USER_ID.to_string(), )) } }, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( OFFSET_ID_TO_USER_ID.to_string(), )) } @@ -117,79 +141,129 @@ impl<'a> RecordSegmentWriter { Some(id_to_data_bf_id) => match id_to_data_bf_id.get(0) { Some(id_to_data_bf_id) => id_to_data_bf_id, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( 
OFFSET_ID_TO_DATA.to_string(), )) } }, None => { - return Err(RecordSegmentCreationError::MissingFile( + return Err(RecordSegmentWriterCreationError::MissingFile( OFFSET_ID_TO_DATA.to_string(), )) } }; + let max_offset_id_bf_id = match segment.file_path.get(MAX_OFFSET_ID) { + Some(max_offset_id_file_id) => match max_offset_id_file_id.get(0) { + Some(max_offset_id_file_id) => max_offset_id_file_id, + None => { + return Err(RecordSegmentWriterCreationError::MissingFile( + MAX_OFFSET_ID.to_string(), + )) + } + }, + None => { + return Err(RecordSegmentWriterCreationError::MissingFile( + MAX_OFFSET_ID.to_string(), + )) + } + }; let user_id_to_bf_uuid = match Uuid::parse_str(user_id_to_id_bf_id) { Ok(user_id_to_bf_uuid) => user_id_to_bf_uuid, - Err(e) => { - return Err(RecordSegmentCreationError::InvalidUuid( + Err(_) => { + return Err(RecordSegmentWriterCreationError::InvalidUuid( USER_ID_TO_OFFSET_ID.to_string(), )) } }; - let id_to_user_id_bf_uuid = match Uuid::parse_str(id_to_user_id_bf_id) { Ok(id_to_user_id_bf_uuid) => id_to_user_id_bf_uuid, - Err(e) => { - return Err(RecordSegmentCreationError::InvalidUuid( + Err(_) => { + return Err(RecordSegmentWriterCreationError::InvalidUuid( OFFSET_ID_TO_USER_ID.to_string(), )) } }; - let id_to_data_bf_uuid = match Uuid::parse_str(id_to_data_bf_id) { Ok(id_to_data_bf_uuid) => id_to_data_bf_uuid, - Err(e) => { - return Err(RecordSegmentCreationError::InvalidUuid( + Err(_) => { + return Err(RecordSegmentWriterCreationError::InvalidUuid( OFFSET_ID_TO_DATA.to_string(), )) } }; + let max_offset_id_bf_uuid = match Uuid::parse_str(max_offset_id_bf_id) { + Ok(max_offset_id_bf_uuid) => max_offset_id_bf_uuid, + Err(_) => { + return Err(RecordSegmentWriterCreationError::InvalidUuid( + MAX_OFFSET_ID.to_string(), + )) + } + }; let user_id_to_id = match blockfile_provider .fork::<&str, u32>(&user_id_to_bf_uuid) .await { Ok(user_id_to_id) => user_id_to_id, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; - let id_to_user_id = match blockfile_provider .fork::(&id_to_user_id_bf_uuid) .await { Ok(id_to_user_id) => id_to_user_id, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; - let id_to_data = match blockfile_provider .fork::(&id_to_data_bf_uuid) .await { Ok(id_to_data) => id_to_data, - Err(e) => return Err(RecordSegmentCreationError::BlockfileCreateError(e)), + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } + }; + let max_offset_id_bf = match blockfile_provider + .fork::<&str, u32>(&max_offset_id_bf_uuid) + .await + { + Ok(max_offset_id) => max_offset_id, + Err(e) => { + return Err(RecordSegmentWriterCreationError::BlockfileCreateError(e)) + } }; - (user_id_to_id, id_to_user_id, id_to_data) + let max_offset_id_bf_reader = match blockfile_provider + .open::<&str, u32>(&max_offset_id_bf_uuid) + .await + { + Ok(max_offset_id_bf_reader) => max_offset_id_bf_reader, + Err(e) => return Err(RecordSegmentWriterCreationError::BlockfileOpenError(e)), + }; + exising_max_offset_id = match max_offset_id_bf_reader.get("", MAX_OFFSET_ID).await { + Ok(max_offset_id) => max_offset_id, + Err(e) => { + return Err(RecordSegmentWriterCreationError::NoExistingOffsetId); + } + }; + + (user_id_to_id, id_to_user_id, id_to_data, max_offset_id_bf) } - _ => return 
Err(RecordSegmentCreationError::IncorrectNumberOfFiles), + _ => return Err(RecordSegmentWriterCreationError::IncorrectNumberOfFiles), }; + println!("Creating with max offset id: {}", exising_max_offset_id); Ok(RecordSegmentWriter { user_id_to_id: Some(user_id_to_id), id_to_user_id: Some(id_to_user_id), id_to_data: Some(id_to_data), - curr_max_offset_id: Arc::new(AtomicU32::new(0)), + max_offset_id: Some(max_offset_id), + curr_max_offset_id: Arc::new(AtomicU32::new(exising_max_offset_id + 1)), id: segment.id, }) } @@ -209,11 +283,11 @@ impl SegmentWriter for RecordSegmentWriter { let flusher_user_id_to_id = self.user_id_to_id.take().unwrap().commit::<&str, u32>(); let flusher_id_to_user_id = self.id_to_user_id.take().unwrap().commit::(); let flusher_id_to_data = self.id_to_data.take().unwrap().commit::(); + let flusher_max_offset_id = self.max_offset_id.take().unwrap().commit::<&str, u32>(); let flusher_user_id_to_id = match flusher_user_id_to_id { Ok(f) => f, Err(e) => { - // TOOD: log and return error return Err(e); } }; @@ -221,7 +295,6 @@ impl SegmentWriter for RecordSegmentWriter { let flusher_id_to_user_id = match flusher_id_to_user_id { Ok(f) => f, Err(e) => { - // TOOD: log and return error return Err(e); } }; @@ -229,7 +302,13 @@ impl SegmentWriter for RecordSegmentWriter { let flusher_id_to_data = match flusher_id_to_data { Ok(f) => f, Err(e) => { - // TOOD: log and return error + return Err(e); + } + }; + + let flusher_max_offset_id = match flusher_max_offset_id { + Ok(f) => f, + Err(e) => { return Err(e); } }; @@ -239,6 +318,7 @@ impl SegmentWriter for RecordSegmentWriter { user_id_to_id_flusher: flusher_user_id_to_id, id_to_user_id_flusher: flusher_id_to_user_id, id_to_data_flusher: flusher_id_to_data, + max_offset_id_flusher: flusher_max_offset_id, }) } } @@ -247,6 +327,7 @@ pub(crate) struct RecordSegmentFlusher { user_id_to_id_flusher: BlockfileFlusher, id_to_user_id_flusher: BlockfileFlusher, id_to_data_flusher: BlockfileFlusher, + max_offset_id_flusher: BlockfileFlusher, } impl Debug for RecordSegmentFlusher { @@ -261,9 +342,11 @@ impl SegmentFlusher for RecordSegmentFlusher { let user_id_to_id_bf_id = self.user_id_to_id_flusher.id(); let id_to_user_id_bf_id = self.id_to_user_id_flusher.id(); let id_to_data_bf_id = self.id_to_data_flusher.id(); + let max_offset_id_bf_id = self.max_offset_id_flusher.id(); let res_user_id_to_id = self.user_id_to_id_flusher.flush::<&str, u32>().await; let res_id_to_user_id = self.id_to_user_id_flusher.flush::().await; let res_id_to_data = self.id_to_data_flusher.flush::().await; + let res_max_offset_id = self.max_offset_id_flusher.flush::<&str, u32>().await; let mut flushed_files = HashMap::new(); @@ -303,6 +386,18 @@ impl SegmentFlusher for RecordSegmentFlusher { } } + match res_max_offset_id { + Ok(f) => { + flushed_files.insert( + MAX_OFFSET_ID.to_string(), + vec![max_offset_id_bf_id.to_string()], + ); + } + Err(e) => { + return Err(e); + } + } + Ok(flushed_files) } } @@ -321,6 +416,7 @@ impl LogMaterializer for RecordSegmentWriter { let next_offset_id = self .curr_max_offset_id .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let metadata = match &log_entry.record.metadata { Some(metadata) => match update_metdata_to_metdata(&metadata) { Ok(metadata) => Some(metadata), @@ -362,6 +458,13 @@ impl LogMaterializer for RecordSegmentWriter { .unwrap() .set("", next_offset_id, log_entry.record.id.as_str()) .await; + println!("Writing to max_offset_id: {}", next_offset_id); + let res = self + .max_offset_id + .as_ref() + .unwrap() + .set("", 
MAX_OFFSET_ID, next_offset_id) + .await; // TODO: use res materialized_records.push(materialized); } @@ -374,3 +477,105 @@ impl LogMaterializer for RecordSegmentWriter { Chunk::new(materialized_records.into()) } } + +pub(crate) struct RecordSegmentReader<'me> { + user_id_to_id: BlockfileReader<'me, &'me str, u32>, + id_to_user_id: BlockfileReader<'me, u32, &'me str>, + id_to_data: BlockfileReader<'me, u32, DataRecord<'me>>, +} + +#[derive(Error, Debug)] +pub enum RecordSegmentReaderCreationError { + #[error("Segment uninitialized")] + UninitializedSegment, + #[error("Blockfile Open Error")] + BlockfileOpenError(#[from] Box), + #[error("Segment has invalid number of files")] + InvalidNumberOfFiles, +} + +impl ChromaError for RecordSegmentReaderCreationError { + fn code(&self) -> ErrorCodes { + match self { + RecordSegmentReaderCreationError::BlockfileOpenError(e) => e.code(), + RecordSegmentReaderCreationError::InvalidNumberOfFiles => ErrorCodes::InvalidArgument, + RecordSegmentReaderCreationError::UninitializedSegment => ErrorCodes::InvalidArgument, + } + } +} + +impl RecordSegmentReader<'_> { + pub(crate) async fn from_segment( + segment: &Segment, + blockfile_provider: &BlockfileProvider, + ) -> Result> { + let (user_id_to_id, id_to_user_id, id_to_data) = match segment.file_path.len() { + 4 => { + let user_id_to_id_bf_id = &segment.file_path.get(USER_ID_TO_OFFSET_ID).unwrap()[0]; + let id_to_user_id_bf_id = &segment.file_path.get(OFFSET_ID_TO_USER_ID).unwrap()[0]; + let id_to_data_bf_id = &segment.file_path.get(OFFSET_ID_TO_DATA).unwrap()[0]; + + let user_id_to_id = match blockfile_provider + .open::<&str, u32>(&Uuid::parse_str(user_id_to_id_bf_id).unwrap()) + .await + { + Ok(user_id_to_id) => user_id_to_id, + Err(e) => { + return Err(Box::new( + RecordSegmentReaderCreationError::BlockfileOpenError(e), + )) + } + }; + + let id_to_user_id = match blockfile_provider + .open::(&Uuid::parse_str(id_to_user_id_bf_id).unwrap()) + .await + { + Ok(id_to_user_id) => id_to_user_id, + Err(e) => { + return Err(Box::new( + RecordSegmentReaderCreationError::BlockfileOpenError(e), + )) + } + }; + + let id_to_data = match blockfile_provider + .open::(&Uuid::parse_str(id_to_data_bf_id).unwrap()) + .await + { + Ok(id_to_data) => id_to_data, + Err(e) => { + return Err(Box::new( + RecordSegmentReaderCreationError::BlockfileOpenError(e), + )) + } + }; + + (user_id_to_id, id_to_user_id, id_to_data) + } + 0 => { + return Err(Box::new( + RecordSegmentReaderCreationError::UninitializedSegment, + )); + } + _ => { + return Err(Box::new( + RecordSegmentReaderCreationError::InvalidNumberOfFiles, + )); + } + }; + + Ok(RecordSegmentReader { + user_id_to_id, + id_to_user_id, + id_to_data, + }) + } + + pub(crate) async fn get_user_id_for_offset_id( + &self, + offset_id: u32, + ) -> Result<&str, Box> { + self.id_to_user_id.get("", offset_id).await + } +} diff --git a/rust/worker/src/server.rs b/rust/worker/src/server.rs index d9a821af56f..ac33793edbd 100644 --- a/rust/worker/src/server.rs +++ b/rust/worker/src/server.rs @@ -1,3 +1,6 @@ +use std::path::PathBuf; + +use crate::blockstore::provider::BlockfileProvider; use crate::chroma_proto; use crate::chroma_proto::{ GetVectorsRequest, GetVectorsResponse, QueryVectorsRequest, QueryVectorsResponse, @@ -6,6 +9,7 @@ use crate::config::{Configurable, QueryServiceConfig}; use crate::errors::ChromaError; use crate::execution::operator::TaskMessage; use crate::execution::orchestration::HnswQueryOrchestrator; +use crate::index::hnsw_provider::HnswIndexProvider; use 
crate::log::log::Log; use crate::sysdb::sysdb::SysDb; use crate::system::{Receiver, System}; @@ -23,6 +27,8 @@ pub struct WorkerServer { // Service dependencies log: Box, sysdb: Box, + hnsw_index_provider: HnswIndexProvider, + blockfile_provider: BlockfileProvider, port: u16, } @@ -33,6 +39,7 @@ impl Configurable for WorkerServer { let sysdb = match crate::sysdb::from_config(sysdb_config).await { Ok(sysdb) => sysdb, Err(err) => { + println!("Failed to create sysdb component: {:?}", err); return Err(err); } }; @@ -40,14 +47,28 @@ impl Configurable for WorkerServer { let log = match crate::log::from_config(log_config).await { Ok(log) => log, Err(err) => { + println!("Failed to create log component: {:?}", err); + return Err(err); + } + }; + let storage = match crate::storage::from_config(&config.storage).await { + Ok(storage) => storage, + Err(err) => { + println!("Failed to create storage component: {:?}", err); return Err(err); } }; + // TODO: inject hnsw index provider somehow + // TODO: inject blockfile provider somehow + // TODO: real path + let path = PathBuf::from("~/tmp"); Ok(WorkerServer { dispatcher: None, system: None, sysdb, log, + hnsw_index_provider: HnswIndexProvider::new(storage.clone(), path), + blockfile_provider: BlockfileProvider::new_arrow(storage), port: config.my_port, }) } @@ -142,6 +163,8 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer { segment_uuid, self.log.clone(), self.sysdb.clone(), + self.hnsw_index_provider.clone(), + self.blockfile_provider.clone(), dispatcher.clone(), ); orchestrator.run().await diff --git a/rust/worker/src/storage/s3.rs b/rust/worker/src/storage/s3.rs index 76b70af3142..5ab330c2989 100644 --- a/rust/worker/src/storage/s3.rs +++ b/rust/worker/src/storage/s3.rs @@ -179,7 +179,6 @@ impl ChromaError for StorageConfigError { #[async_trait] impl Configurable for S3Storage { async fn try_from_config(config: &StorageConfig) -> Result> { - println!("Creating storage with config: {:?}", config); match &config { StorageConfig::S3(s3_config) => { let client = match &s3_config.credentials { diff --git a/rust/worker/src/types/segment.rs b/rust/worker/src/types/segment.rs index 0027d65d09a..5cfd8875596 100644 --- a/rust/worker/src/types/segment.rs +++ b/rust/worker/src/types/segment.rs @@ -14,6 +14,18 @@ pub(crate) enum SegmentType { Sqlite, } +impl From for String { + fn from(segment_type: SegmentType) -> String { + match segment_type { + SegmentType::HnswDistributed => { + "urn:chroma:segment/vector/hnsw-distributed".to_string() + } + SegmentType::Record => "urn:chroma:segment/record".to_string(), + SegmentType::Sqlite => "urn:chroma:segment/metadata/sqlite".to_string(), + } + } +} + #[derive(Clone, Debug, PartialEq)] pub(crate) struct Segment { pub(crate) id: Uuid,