diff --git a/rust/blockstore/src/arrow/block/delta/data_record.rs b/rust/blockstore/src/arrow/block/delta/data_record.rs index 6a4e3ac8ad5..1c5ef529ebd 100644 --- a/rust/blockstore/src/arrow/block/delta/data_record.rs +++ b/rust/blockstore/src/arrow/block/delta/data_record.rs @@ -5,8 +5,8 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, BinaryBuilder, FixedSizeListBuilder, Float32Builder, StringBuilder, - StructArray, + Array, ArrayRef, BinaryBuilder, FixedSizeListBuilder, Float32Builder, RecordBatch, + StringBuilder, StructArray, }, datatypes::{Field, Fields}, util::bit_util, @@ -340,57 +340,63 @@ impl DataRecordStorage { inner.storage.len() } - pub(super) fn build_keys(&self, builder: BlockKeyArrowBuilder) -> BlockKeyArrowBuilder { - let inner = self.inner.read(); - let mut builder = builder; - for (key, _) in inner.storage.iter() { - builder.add_key(key.clone()); - } - builder - } - - pub(super) fn to_arrow(&self) -> (Field, ArrayRef) { - let inner = self.inner.read(); - - let item_capacity = inner.storage.len(); + pub(super) fn to_arrow( + self, + key_builder: BlockKeyArrowBuilder, + ) -> Result { + // build arrow key. + let mut key_builder = key_builder; let mut embedding_builder; let mut id_builder; let mut metadata_builder; let mut document_builder; let embedding_dim; - if item_capacity == 0 { - // ok to initialize fixed size float list with fixed size as 0. - embedding_dim = 0; - embedding_builder = FixedSizeListBuilder::new(Float32Builder::new(), 0); - id_builder = StringBuilder::new(); - metadata_builder = BinaryBuilder::new(); - document_builder = StringBuilder::new(); - } else { - embedding_dim = inner.storage.iter().next().unwrap().1 .1.len(); - // Assumes all embeddings are of the same length, which is guaranteed by calling code - // TODO: validate this assumption by throwing an error if it's not true - let total_embedding_count = embedding_dim * item_capacity; - id_builder = StringBuilder::with_capacity(item_capacity, inner.id_size); - embedding_builder = FixedSizeListBuilder::with_capacity( - Float32Builder::with_capacity(total_embedding_count), - embedding_dim as i32, - item_capacity, - ); - metadata_builder = BinaryBuilder::with_capacity(item_capacity, inner.metadata_size); - document_builder = StringBuilder::with_capacity(item_capacity, inner.document_size); - } - - let iter = inner.storage.iter(); - for (_key, (id, embedding, metadata, document)) in iter { - id_builder.append_value(id); - let embedding_arr = embedding_builder.values(); - for entry in embedding.iter() { - embedding_arr.append_value(*entry); + match Arc::try_unwrap(self.inner) { + Ok(inner) => { + let inner = inner.into_inner(); + let storage = inner.storage; + let item_capacity = storage.len(); + if item_capacity == 0 { + // ok to initialize fixed size float list with fixed size as 0. + embedding_dim = 0; + embedding_builder = FixedSizeListBuilder::new(Float32Builder::new(), 0); + id_builder = StringBuilder::new(); + metadata_builder = BinaryBuilder::new(); + document_builder = StringBuilder::new(); + } else { + embedding_dim = storage.iter().next().unwrap().1 .1.len(); + // Assumes all embeddings are of the same length, which is guaranteed by calling code + // TODO: validate this assumption by throwing an error if it's not true + let total_embedding_count = embedding_dim * item_capacity; + id_builder = StringBuilder::with_capacity(item_capacity, inner.id_size); + embedding_builder = FixedSizeListBuilder::with_capacity( + Float32Builder::with_capacity(total_embedding_count), + embedding_dim as i32, + item_capacity, + ); + metadata_builder = + BinaryBuilder::with_capacity(item_capacity, inner.metadata_size); + document_builder = + StringBuilder::with_capacity(item_capacity, inner.document_size); + } + for (key, (id, embedding, metadata, document)) in storage.into_iter() { + key_builder.add_key(key); + id_builder.append_value(id); + let embedding_arr = embedding_builder.values(); + for entry in embedding { + embedding_arr.append_value(entry); + } + embedding_builder.append(true); + metadata_builder.append_option(metadata.as_deref()); + document_builder.append_option(document.as_deref()); + } + } + Err(_) => { + panic!("Invariant violation: SingleColumnStorage inner should have only one reference."); } - embedding_builder.append(true); - metadata_builder.append_option(metadata.as_deref()); - document_builder.append_option(document.as_deref()); } + // Build arrow key with fields. + let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); let id_field = Field::new("id", arrow::datatypes::DataType::Utf8, true); let embedding_field = Field::new( @@ -439,9 +445,12 @@ impl DataRecordStorage { arrow::datatypes::DataType::Struct(struct_fields), true, ); - ( + let value_arr = (&struct_arr as &dyn Array).slice(0, struct_arr.len()); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + prefix_field, + key_field, struct_field, - (&struct_arr as &dyn Array).slice(0, struct_arr.len()), - ) + ])); + RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]) } } diff --git a/rust/blockstore/src/arrow/block/delta/delta.rs b/rust/blockstore/src/arrow/block/delta/delta.rs index 52737e3b12b..e21da3f5732 100644 --- a/rust/blockstore/src/arrow/block/delta/delta.rs +++ b/rust/blockstore/src/arrow/block/delta/delta.rs @@ -60,7 +60,7 @@ impl BlockDelta { self.builder.get_size::() } - pub fn finish(&self) -> RecordBatch { + pub fn finish(self) -> RecordBatch { self.builder.to_record_batch::() } @@ -121,7 +121,6 @@ impl BlockDelta { #[cfg(test)] mod test { use crate::arrow::{block::Block, config::TEST_MAX_BLOCK_SIZE_BYTES, provider::BlockManager}; - use arrow::array::Int32Array; use chroma_cache::{ cache::Cache, config::{CacheConfig, UnboundedCacheConfig}, @@ -154,7 +153,7 @@ mod test { let storage = Storage::Local(LocalStorage::new(path)); let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {})); let block_manager = BlockManager::new(storage, TEST_MAX_BLOCK_SIZE_BYTES, cache); - let delta = block_manager.create::<&str, Int32Array>(); + let delta = block_manager.create::<&str, Vec>(); let n = 2000; for i in 0..n { @@ -163,27 +162,25 @@ mod test { let value_len: usize = rand::thread_rng().gen_range(1..100); let mut new_vec = Vec::with_capacity(value_len); for _ in 0..value_len { - new_vec.push(random::()); + new_vec.push(random::()); } - delta.add::<&str, Int32Array>(prefix, &key, Int32Array::from(new_vec)); + delta.add::<&str, Vec>(prefix, &key, new_vec); } - let size = delta.get_size::<&str, Int32Array>(); - // TODO: should commit take ownership of delta? - // Semantically, that makes sense, since a delta is unsuable after commit + let size = delta.get_size::<&str, Vec>(); - let block = block_manager.commit::<&str, Int32Array>(&delta); + let block = block_manager.commit::<&str, Vec>(delta); let mut values_before_flush = vec![]; for i in 0..n { let key = format!("key{}", i); - let read = block.get::<&str, Int32Array>("prefix", &key).unwrap(); - values_before_flush.push(read); + let read = block.get::<&str, &[u32]>("prefix", &key).unwrap(); + values_before_flush.push(read.to_vec()); } block_manager.flush(&block).await.unwrap(); let block = block_manager.get(&block.clone().id).await.unwrap(); for i in 0..n { let key = format!("key{}", i); - let read = block.get::<&str, Int32Array>("prefix", &key).unwrap(); + let read = block.get::<&str, &[u32]>("prefix", &key).unwrap(); assert_eq!(read, values_before_flush[i]); } test_save_load_size(path, &block); @@ -208,7 +205,7 @@ mod test { delta.add(prefix, key.as_str(), value.to_owned()); } let size = delta.get_size::<&str, String>(); - let block = block_manager.commit::<&str, String>(&delta); + let block = block_manager.commit::<&str, String>(delta); let mut values_before_flush = vec![]; for i in 0..n { let key = format!("key{}", i); @@ -237,7 +234,7 @@ mod test { // test fork let forked_block = block_manager.fork::<&str, String>(&delta_id).await; let new_id = forked_block.id.clone(); - let block = block_manager.commit::<&str, String>(&forked_block); + let block = block_manager.commit::<&str, String>(forked_block); block_manager.flush(&block).await.unwrap(); let forked_block = block_manager.get(&new_id).await.unwrap(); for i in 0..n { @@ -265,7 +262,8 @@ mod test { } let size = delta.get_size::(); - let block = block_manager.commit::(&delta); + let delta_id = delta.id.clone(); + let block = block_manager.commit::(delta); let mut values_before_flush = vec![]; for i in 0..n { let key = i as f32; @@ -273,7 +271,7 @@ mod test { values_before_flush.push(read); } block_manager.flush(&block).await.unwrap(); - let block = block_manager.get(&delta.id).await.unwrap(); + let block = block_manager.get(&delta_id).await.unwrap(); assert_eq!(size, block.get_size()); for i in 0..n { let key = i as f32; @@ -302,9 +300,10 @@ mod test { } let size = delta.get_size::<&str, RoaringBitmap>(); - let block = block_manager.commit::<&str, RoaringBitmap>(&delta); + let delta_id = delta.id.clone(); + let block = block_manager.commit::<&str, RoaringBitmap>(delta); block_manager.flush(&block).await.unwrap(); - let block = block_manager.get(&delta.id).await.unwrap(); + let block = block_manager.get(&delta_id).await.unwrap(); assert_eq!(size, block.get_size()); @@ -366,9 +365,10 @@ mod test { } let size = delta.get_size::<&str, &DataRecord>(); - let block = block_manager.commit::<&str, &DataRecord>(&delta); + let delta_id = delta.id.clone(); + let block = block_manager.commit::<&str, &DataRecord>(delta); block_manager.flush(&block).await.unwrap(); - let block = block_manager.get(&delta.id).await.unwrap(); + let block = block_manager.get(&delta_id).await.unwrap(); for i in 0..3 { let read = block.get::<&str, DataRecord>("", ids[i]).unwrap(); assert_eq!(read.id, ids[i]); @@ -400,9 +400,10 @@ mod test { } let size = delta.get_size::(); - let block = block_manager.commit::(&delta); + let delta_id = delta.id.clone(); + let block = block_manager.commit::(delta); block_manager.flush(&block).await.unwrap(); - let block = block_manager.get(&delta.id).await.unwrap(); + let block = block_manager.get(&delta_id).await.unwrap(); assert_eq!(size, block.get_size()); // test save/load @@ -427,7 +428,7 @@ mod test { delta.add(prefix, key, value); } let size = delta.get_size::(); - let block = block_manager.commit::(&delta); + let block = block_manager.commit::(delta); let mut values_before_flush = vec![]; for i in 0..n { let key = i as u32; @@ -456,7 +457,7 @@ mod test { // test fork let forked_block = block_manager.fork::(&delta_id).await; let new_id = forked_block.id.clone(); - let block = block_manager.commit::(&forked_block); + let block = block_manager.commit::(forked_block); block_manager.flush(&block).await.unwrap(); let forked_block = block_manager.get(&new_id).await.unwrap(); for i in 0..n { diff --git a/rust/blockstore/src/arrow/block/delta/single_column_storage.rs b/rust/blockstore/src/arrow/block/delta/single_column_storage.rs index 6acdad5d35d..82cd4ca628b 100644 --- a/rust/blockstore/src/arrow/block/delta/single_column_storage.rs +++ b/rust/blockstore/src/arrow/block/delta/single_column_storage.rs @@ -6,8 +6,7 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, BinaryBuilder, Int32Array, Int32Builder, ListBuilder, StringBuilder, - UInt32Builder, + Array, BinaryBuilder, ListBuilder, RecordBatch, StringBuilder, UInt32Array, UInt32Builder, }, datatypes::Field, util::bit_util, @@ -86,16 +85,6 @@ impl SingleColumnStorage { + value_validity_bytes } - pub(super) fn build_keys(&self, builder: BlockKeyArrowBuilder) -> BlockKeyArrowBuilder { - let inner = self.inner.read(); - let storage = &inner.storage; - let mut builder = builder; - for (key, _) in storage.iter() { - builder.add_key(key.clone()); - } - builder - } - pub fn add(&self, prefix: &str, key: KeyWrapper, value: T) { let mut inner = self.inner.write(); let key_len = key.get_size(); @@ -113,11 +102,12 @@ impl SingleColumnStorage { inner.size_tracker.subtract_key_size(key_len); inner.size_tracker.subtract_prefix_size(prefix.len()); } + let value_size = value.get_size(); - inner.storage.insert(composite_key, value.to_owned()); + inner.storage.insert(composite_key, value); inner.size_tracker.add_prefix_size(prefix.len()); inner.size_tracker.add_key_size(key_len); - inner.size_tracker.add_value_size(value.get_size()); + inner.size_tracker.add_value_size(value_size); } pub fn delete(&self, prefix: &str, key: KeyWrapper) { @@ -229,7 +219,12 @@ impl SingleColumnStorage { } impl SingleColumnStorage { - pub(super) fn to_arrow(&self) -> (Field, ArrayRef) { + pub(super) fn to_arrow( + self, + key_builder: BlockKeyArrowBuilder, + ) -> Result { + // Build key and value. + let mut key_builder = key_builder; let item_capacity = self.len(); let mut value_builder; if item_capacity == 0 { @@ -237,115 +232,173 @@ impl SingleColumnStorage { } else { value_builder = StringBuilder::with_capacity(item_capacity, self.get_value_size()); } - - let inner = self.inner.read(); - let storage = &inner.storage; - - for (_, value) in storage.iter() { - value_builder.append_value(value); + match Arc::try_unwrap(self.inner) { + Ok(inner) => { + let storage = inner.into_inner().storage; + for (key, value) in storage.into_iter() { + key_builder.add_key(key); + value_builder.append_value(value); + } + } + Err(_) => { + panic!("Invariant violation: SingleColumnStorage inner should have only one reference."); + } } - + // Build arrow key with fields. + let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); + // Build arrow value with fields. let value_field = Field::new("value", arrow::datatypes::DataType::Utf8, false); let value_arr = value_builder.finish(); - ( + let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len()); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + prefix_field, + key_field, value_field, - (&value_arr as &dyn Array).slice(0, value_arr.len()), - ) + ])); + RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]) } } -impl SingleColumnStorage { - pub(super) fn to_arrow(&self) -> (Field, ArrayRef) { +impl SingleColumnStorage> { + pub(super) fn to_arrow( + self, + key_builder: BlockKeyArrowBuilder, + ) -> Result { + // Build key and value. + let mut key_builder = key_builder; let item_capacity = self.len(); - let inner = self.inner.read(); - let storage = &inner.storage; - let total_value_count = storage.iter().fold(0, |acc, (_, value)| acc + value.len()); - let mut value_builder; - if item_capacity == 0 { - value_builder = ListBuilder::new(Int32Builder::new()); - } else { - value_builder = ListBuilder::with_capacity( - Int32Builder::with_capacity(total_value_count), - item_capacity, - ); - } - - for (_, value) in storage.iter() { - value_builder.append_value(value); + match Arc::try_unwrap(self.inner) { + Ok(inner) => { + let storage = inner.into_inner().storage; + let total_value_count = storage.iter().fold(0, |acc, (_, value)| acc + value.len()); + if item_capacity == 0 { + value_builder = ListBuilder::new(UInt32Builder::new()); + } else { + value_builder = ListBuilder::with_capacity( + UInt32Builder::with_capacity(total_value_count), + item_capacity, + ); + } + for (key, value) in storage.into_iter() { + key_builder.add_key(key); + value_builder.append_value(&UInt32Array::from(value)); + } + } + Err(_) => { + panic!("Invariant violation: SingleColumnStorage inner should have only one reference."); + } } + // Build arrow key and value with fields. + let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); let value_field = Field::new( "value", arrow::datatypes::DataType::List(Arc::new(Field::new( "item", - arrow::datatypes::DataType::Int32, + arrow::datatypes::DataType::UInt32, true, ))), true, ); let value_arr = value_builder.finish(); - ( + let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len()); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + prefix_field, + key_field, value_field, - (&value_arr as &dyn Array).slice(0, value_arr.len()), - ) + ])); + RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]) } } impl SingleColumnStorage { - pub(super) fn to_arrow(&self) -> (Field, ArrayRef) { - let inner = self.inner.read(); - let storage = &inner.storage; - let item_capacity = storage.len(); + pub(super) fn to_arrow( + self, + key_builder: BlockKeyArrowBuilder, + ) -> Result { + // Build key and value. + let mut key_builder = key_builder; let mut value_builder; - if item_capacity == 0 { - value_builder = UInt32Builder::new(); - } else { - value_builder = UInt32Builder::with_capacity(item_capacity); - } - for (_, value) in storage.iter() { - value_builder.append_value(*value); + match Arc::try_unwrap(self.inner) { + Ok(inner) => { + let storage = inner.into_inner().storage; + let item_capacity = storage.len(); + if item_capacity == 0 { + value_builder = UInt32Builder::new(); + } else { + value_builder = UInt32Builder::with_capacity(item_capacity); + } + for (key, value) in storage.into_iter() { + key_builder.add_key(key); + value_builder.append_value(value); + } + } + Err(_) => { + panic!("Invariant violation: SingleColumnStorage inner should have only one reference."); + } } + // Build arrow key with fields. + let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); let value_field = Field::new("value", arrow::datatypes::DataType::UInt32, false); let value_arr = value_builder.finish(); - ( + let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len()); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + prefix_field, + key_field, value_field, - (&value_arr as &dyn Array).slice(0, value_arr.len()), - ) + ])); + RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]) } } impl SingleColumnStorage { - pub(super) fn to_arrow(&self) -> (Field, ArrayRef) { - let inner = self.inner.read(); - let storage = &inner.storage; + pub(super) fn to_arrow( + self, + key_builder: BlockKeyArrowBuilder, + ) -> Result { + // Build key. + let mut key_builder = key_builder; let item_capacity = self.len(); - let total_value_count = storage - .iter() - .fold(0, |acc, (_, value)| acc + value.get_size()); let mut value_builder; - if item_capacity == 0 { - value_builder = BinaryBuilder::new(); - } else { - value_builder = BinaryBuilder::with_capacity(item_capacity, total_value_count); - } - - for (_, value) in storage.iter() { - let mut serialized = Vec::with_capacity(value.serialized_size()); - let res = value.serialize_into(&mut serialized); - // TODO: proper error handling - let serialized = match res { - Ok(_) => serialized, - Err(e) => panic!("Failed to serialize RoaringBitmap: {}", e), - }; - value_builder.append_value(serialized); + match Arc::try_unwrap(self.inner) { + Ok(inner) => { + let storage = inner.into_inner().storage; + let total_value_count = storage + .iter() + .fold(0, |acc, (_, value)| acc + value.get_size()); + if item_capacity == 0 { + value_builder = BinaryBuilder::new(); + } else { + value_builder = BinaryBuilder::with_capacity(item_capacity, total_value_count); + } + for (key, value) in storage.into_iter() { + key_builder.add_key(key); + let mut serialized = Vec::with_capacity(value.serialized_size()); + let res = value.serialize_into(&mut serialized); + // TODO: proper error handling + let serialized = match res { + Ok(_) => serialized, + Err(e) => panic!("Failed to serialize RoaringBitmap: {}", e), + }; + value_builder.append_value(serialized); + } + } + Err(_) => { + panic!("Invariant violation: SingleColumnStorage inner should have only one reference."); + } } + // Build arrow key with fields. + let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); let value_field = Field::new("value", arrow::datatypes::DataType::Binary, true); let value_arr = value_builder.finish(); - ( + let value_arr = (&value_arr as &dyn Array).slice(0, value_arr.len()); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + prefix_field, + key_field, value_field, - (&value_arr as &dyn Array).slice(0, value_arr.len()), - ) + ])); + RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]) } } diff --git a/rust/blockstore/src/arrow/block/delta/storage.rs b/rust/blockstore/src/arrow/block/delta/storage.rs index addb7ec7977..4ab92d36229 100644 --- a/rust/blockstore/src/arrow/block/delta/storage.rs +++ b/rust/blockstore/src/arrow/block/delta/storage.rs @@ -5,8 +5,7 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, BooleanBuilder, Float32Builder, Int32Array, RecordBatch, StringBuilder, - UInt32Builder, + Array, ArrayRef, BooleanBuilder, Float32Builder, RecordBatch, StringBuilder, UInt32Builder, }, datatypes::Field, }; @@ -14,13 +13,12 @@ use roaring::RoaringBitmap; use std::{ fmt, fmt::{Debug, Formatter}, - sync::Arc, }; #[derive(Clone)] pub enum BlockStorage { String(SingleColumnStorage), - Int32Array(SingleColumnStorage), + VecUInt32(SingleColumnStorage>), UInt32(SingleColumnStorage), RoaringBitmap(SingleColumnStorage), DataRecord(DataRecordStorage), @@ -30,7 +28,7 @@ impl Debug for BlockStorage { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { BlockStorage::String(_) => write!(f, "String"), - BlockStorage::Int32Array(_) => write!(f, "Int32Array"), + BlockStorage::VecUInt32(_) => write!(f, "VecUInt32"), BlockStorage::UInt32(_) => write!(f, "UInt32"), BlockStorage::RoaringBitmap(_) => write!(f, "RoaringBitmap"), BlockStorage::DataRecord(_) => write!(f, "DataRecord"), @@ -91,7 +89,7 @@ impl BlockKeyArrowBuilder { } } - fn to_arrow(&mut self) -> (Field, ArrayRef, Field, ArrayRef) { + pub fn to_arrow(&mut self) -> (Field, ArrayRef, Field, ArrayRef) { match self { BlockKeyArrowBuilder::String((ref mut prefix_builder, ref mut key_builder)) => { let prefix_field = Field::new("prefix", arrow::datatypes::DataType::Utf8, false); @@ -151,7 +149,7 @@ impl BlockStorage { BlockStorage::String(builder) => builder.get_prefix_size(), BlockStorage::UInt32(builder) => builder.get_prefix_size(), BlockStorage::DataRecord(builder) => builder.get_prefix_size(), - BlockStorage::Int32Array(builder) => builder.get_prefix_size(), + BlockStorage::VecUInt32(builder) => builder.get_prefix_size(), BlockStorage::RoaringBitmap(builder) => builder.get_prefix_size(), } } @@ -161,7 +159,7 @@ impl BlockStorage { BlockStorage::String(builder) => builder.get_key_size(), BlockStorage::UInt32(builder) => builder.get_key_size(), BlockStorage::DataRecord(builder) => builder.get_key_size(), - BlockStorage::Int32Array(builder) => builder.get_key_size(), + BlockStorage::VecUInt32(builder) => builder.get_key_size(), BlockStorage::RoaringBitmap(builder) => builder.get_key_size(), } } @@ -171,7 +169,7 @@ impl BlockStorage { BlockStorage::String(builder) => builder.get_min_key(), BlockStorage::UInt32(builder) => builder.get_min_key(), BlockStorage::DataRecord(builder) => builder.get_min_key(), - BlockStorage::Int32Array(builder) => builder.get_min_key(), + BlockStorage::VecUInt32(builder) => builder.get_min_key(), BlockStorage::RoaringBitmap(builder) => builder.get_min_key(), } } @@ -182,7 +180,7 @@ impl BlockStorage { BlockStorage::String(builder) => builder.get_size::(), BlockStorage::UInt32(builder) => builder.get_size::(), BlockStorage::DataRecord(builder) => builder.get_size::(), - BlockStorage::Int32Array(builder) => builder.get_size::(), + BlockStorage::VecUInt32(builder) => builder.get_size::(), BlockStorage::RoaringBitmap(builder) => builder.get_size::(), } } @@ -201,9 +199,9 @@ impl BlockStorage { let (split_key, storage) = builder.split::(split_size); (split_key, BlockStorage::DataRecord(storage)) } - BlockStorage::Int32Array(builder) => { + BlockStorage::VecUInt32(builder) => { let (split_key, storage) = builder.split::(split_size); - (split_key, BlockStorage::Int32Array(storage)) + (split_key, BlockStorage::VecUInt32(storage)) } BlockStorage::RoaringBitmap(builder) => { let (split_key, storage) = builder.split::(split_size); @@ -217,47 +215,35 @@ impl BlockStorage { BlockStorage::String(builder) => builder.len(), BlockStorage::UInt32(builder) => builder.len(), BlockStorage::DataRecord(builder) => builder.len(), - BlockStorage::Int32Array(builder) => builder.len(), + BlockStorage::VecUInt32(builder) => builder.len(), BlockStorage::RoaringBitmap(builder) => builder.len(), } } - pub fn to_record_batch(&self) -> RecordBatch { - let mut key_builder = + pub fn to_record_batch(self) -> RecordBatch { + let key_builder = K::get_arrow_builder(self.len(), self.get_prefix_size(), self.get_key_size()); match self { BlockStorage::String(builder) => { - key_builder = builder.build_keys(key_builder); + // TODO: handle error + builder.to_arrow(key_builder).unwrap() } BlockStorage::UInt32(builder) => { - key_builder = builder.build_keys(key_builder); + // TODO: handle error + builder.to_arrow(key_builder).unwrap() } BlockStorage::DataRecord(builder) => { - key_builder = builder.build_keys(key_builder); + // TODO: handle error + builder.to_arrow(key_builder).unwrap() } - BlockStorage::Int32Array(builder) => { - key_builder = builder.build_keys(key_builder); + BlockStorage::VecUInt32(builder) => { + // TODO: handle error + builder.to_arrow(key_builder).unwrap() } BlockStorage::RoaringBitmap(builder) => { - key_builder = builder.build_keys(key_builder); + // TODO: handle error + builder.to_arrow(key_builder).unwrap() } } - - let (prefix_field, prefix_arr, key_field, key_arr) = key_builder.to_arrow(); - let (value_field, value_arr) = match self { - BlockStorage::String(builder) => builder.to_arrow(), - BlockStorage::UInt32(builder) => builder.to_arrow(), - BlockStorage::DataRecord(builder) => builder.to_arrow(), - BlockStorage::Int32Array(builder) => builder.to_arrow(), - BlockStorage::RoaringBitmap(builder) => builder.to_arrow(), - }; - let schema = Arc::new(arrow::datatypes::Schema::new(vec![ - prefix_field, - key_field, - value_field, - ])); - let record_batch = RecordBatch::try_new(schema, vec![prefix_arr, key_arr, value_arr]); - // TODO: handle error - record_batch.unwrap() } } diff --git a/rust/blockstore/src/arrow/block/value/int32array_value.rs b/rust/blockstore/src/arrow/block/value/int32array_value.rs index 0d91ded22c9..da3de68f33f 100644 --- a/rust/blockstore/src/arrow/block/value/int32array_value.rs +++ b/rust/blockstore/src/arrow/block/value/int32array_value.rs @@ -6,16 +6,16 @@ use crate::{ key::KeyWrapper, }; use arrow::{ - array::{Array, Int32Array, ListArray}, + array::{Array, ListArray, UInt32Array}, util::bit_util, }; -use std::sync::Arc; +use std::{mem::size_of, sync::Arc}; -impl ArrowWriteableValue for Int32Array { - type ReadableValue<'referred_data> = Int32Array; +impl ArrowWriteableValue for Vec { + type ReadableValue<'referred_data> = &'referred_data [u32]; fn offset_size(item_count: usize) -> usize { - bit_util::round_upto_multiple_of_64((item_count + 1) * 4) + bit_util::round_upto_multiple_of_64((item_count + 1) * size_of::()) } fn validity_size(_item_count: usize) -> usize { @@ -24,14 +24,8 @@ impl ArrowWriteableValue for Int32Array { fn add(prefix: &str, key: KeyWrapper, value: Self, delta: &BlockDelta) { match &delta.builder { - BlockStorage::Int32Array(builder) => { - // We have to clone the value in this odd way here because when reading out of a block we get the entire array - let mut new_vec = Vec::with_capacity(value.len()); - for i in 0..value.len() { - new_vec.push(value.value(i)); - } - let new_arr = Int32Array::from(new_vec); - builder.add(prefix, key, new_arr); + BlockStorage::VecUInt32(builder) => { + builder.add(prefix, key, value); } _ => panic!("Invalid builder type"), } @@ -39,7 +33,7 @@ impl ArrowWriteableValue for Int32Array { fn delete(prefix: &str, key: KeyWrapper, delta: &BlockDelta) { match &delta.builder { - BlockStorage::Int32Array(builder) => { + BlockStorage::VecUInt32(builder) => { builder.delete(prefix, key); } _ => panic!("Invalid builder type"), @@ -47,19 +41,21 @@ impl ArrowWriteableValue for Int32Array { } fn get_delta_builder() -> BlockStorage { - BlockStorage::Int32Array(SingleColumnStorage::new()) + BlockStorage::VecUInt32(SingleColumnStorage::new()) } } -impl ArrowReadableValue<'_> for Int32Array { - fn get(array: &Arc, index: usize) -> Self { - let arr = array +impl<'referred_data> ArrowReadableValue<'referred_data> for &'referred_data [u32] { + fn get(array: &'referred_data Arc, index: usize) -> Self { + let list_array = array.as_any().downcast_ref::().unwrap(); + let start = list_array.value_offsets()[index] as usize; + let end = list_array.value_offsets()[index + 1] as usize; + let u32array = list_array + .values() .as_any() - .downcast_ref::() - .unwrap() - .value(index); - // Cloning an arrow array is cheap, since they are immutable and backed by Arc'ed data - arr.as_any().downcast_ref::().unwrap().clone() + .downcast_ref::() + .unwrap(); + &u32array.values()[start..end] } fn add_to_delta( @@ -68,6 +64,6 @@ impl ArrowReadableValue<'_> for Int32Array { value: Self, delta: &mut BlockDelta, ) { - delta.add(prefix, key, value.clone()); + delta.add(prefix, key, value.to_vec()); } } diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 33a2c0f2f5b..ed7c7f6baad 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -97,7 +97,7 @@ impl ArrowBlockfileWriter { self, ) -> Result> { let mut blocks = Vec::new(); - for delta in self.block_deltas.lock().values() { + for (_, delta) in self.block_deltas.lock().drain() { let mut removed = false; // Skip empty blocks. Also, remove from sparse index. if delta.len() == 0 { @@ -580,7 +580,6 @@ mod tests { use crate::{ arrow::config::TEST_MAX_BLOCK_SIZE_BYTES, arrow::provider::ArrowBlockfileProvider, }; - use arrow::array::Int32Array; use chroma_cache::{ cache::Cache, config::{CacheConfig, UnboundedCacheConfig}, @@ -605,26 +604,23 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id = writer.id(); let prefix_1 = "key"; let key1 = "zzzz"; - let value1 = Int32Array::from(vec![1, 2, 3]); + let value1 = vec![1, 2, 3]; writer.set(prefix_1, key1, value1.clone()).await.unwrap(); let prefix_2 = "key"; let key2 = "aaaa"; - let value2 = Int32Array::from(vec![4, 5, 6]); + let value2 = vec![4, 5, 6]; writer.set(prefix_2, key2, value2).await.unwrap(); - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); - let reader = blockfile_provider - .open::<&str, Int32Array>(&id) - .await - .unwrap(); + let reader = blockfile_provider.open::<&str, &[u32]>(&id).await.unwrap(); let count = reader.count().await; match count { @@ -819,32 +815,29 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id = writer.id(); let prefix_1 = "key"; let key1 = "zzzz"; - let value1 = Int32Array::from(vec![1, 2, 3]); + let value1 = vec![1, 2, 3]; writer.set(prefix_1, key1, value1).await.unwrap(); let prefix_2 = "key"; let key2 = "aaaa"; - let value2 = Int32Array::from(vec![4, 5, 6]); + let value2 = vec![4, 5, 6]; writer.set(prefix_2, key2, value2).await.unwrap(); - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); - let reader = blockfile_provider - .open::<&str, Int32Array>(&id) - .await - .unwrap(); + let reader = blockfile_provider.open::<&str, &[u32]>(&id).await.unwrap(); let value = reader.get(prefix_1, key1).await.unwrap(); - assert_eq!(value.values(), &[1, 2, 3]); + assert_eq!(value, [1, 2, 3]); let value = reader.get(prefix_2, key2).await.unwrap(); - assert_eq!(value.values(), &[4, 5, 6]); + assert_eq!(value, [4, 5, 6]); } #[tokio::test] @@ -859,28 +852,28 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id_1 = writer.id(); let n = 1200; for i in 0..n { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_1) + .open::<&str, &[u32]>(&id_1) .await .unwrap(); for i in 0..n { let key = format!("{:04}", i); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, [i]); } // Sparse index should have 3 blocks @@ -894,28 +887,28 @@ mod tests { // Add 5 new entries to the first block let writer = blockfile_provider - .fork::<&str, Int32Array>(&id_1) + .fork::<&str, Vec>(&id_1) .await .unwrap(); let id_2 = writer.id(); for i in 0..5 { let key = format!("{:05}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_2) + .open::<&str, &[u32]>(&id_2) .await .unwrap(); for i in 0..5 { let key = format!("{:05}", i); println!("Getting key: {}", key); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, [i]); } // Sparse index should still have 3 blocks @@ -929,26 +922,26 @@ mod tests { // Add 1200 more entries, causing splits let writer = blockfile_provider - .fork::<&str, Int32Array>(&id_2) + .fork::<&str, Vec>(&id_2) .await .unwrap(); let id_3 = writer.id(); for i in n..n * 2 { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_3) + .open::<&str, &[u32]>(&id_3) .await .unwrap(); for i in n..n * 2 { let key = format!("{:04}", i); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, [i]); } // Sparse index should have 6 blocks @@ -973,33 +966,33 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id_1 = writer.id(); // Add the larger keys first then smaller. let n = 1200; for i in n..n * 2 { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } for i in 0..n { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_1) + .open::<&str, &[u32]>(&id_1) .await .unwrap(); for i in 0..n * 2 { let key = format!("{:04}", i); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, &[i]); } } @@ -1320,26 +1313,26 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id_1 = writer.id(); let n = 1200; for i in 0..n { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_1) + .open::<&str, &[u32]>(&id_1) .await .unwrap(); for i in 0..n { let expected_key = format!("{:04}", i); - let expected_value = Int32Array::from(vec![i]); + let expected_value = vec![i]; let res = reader.get_at_index(i as usize).await.unwrap(); assert_eq!(res.0, "key"); assert_eq!(res.1, expected_key); @@ -1359,26 +1352,26 @@ mod tests { block_cache, sparse_index_cache, ); - let writer = blockfile_provider.create::<&str, Int32Array>().unwrap(); + let writer = blockfile_provider.create::<&str, Vec>().unwrap(); let id_1 = writer.id(); // Add the larger keys first then smaller. let n = 1200; for i in n..n * 2 { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } for i in 0..n { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer.set("key", key.as_str(), value).await.unwrap(); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); // Create another writer. let writer = blockfile_provider - .fork::<&str, Int32Array>(&id_1) + .fork::<&str, Vec>(&id_1) .await .expect("BlockfileWriter fork unsuccessful"); // Delete everything but the last 10 keys. @@ -1386,16 +1379,16 @@ mod tests { for i in 0..delete_end { let key = format!("{:04}", i); writer - .delete::<&str, Int32Array>("key", key.as_str()) + .delete::<&str, Vec>("key", key.as_str()) .await .expect("Delete failed"); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); let id_2 = flusher.id(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_2) + .open::<&str, &[u32]>(&id_2) .await .unwrap(); @@ -1407,35 +1400,35 @@ mod tests { for i in delete_end..n * 2 { let key = format!("{:04}", i); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, [i]); } let writer = blockfile_provider - .fork::<&str, Int32Array>(&id_1) + .fork::<&str, Vec>(&id_1) .await .expect("BlockfileWriter fork unsuccessful"); // Add everything back. for i in 0..delete_end { let key = format!("{:04}", i); - let value = Int32Array::from(vec![i]); + let value = vec![i]; writer - .set::<&str, Int32Array>("key", key.as_str(), value) + .set::<&str, Vec>("key", key.as_str(), value) .await .expect("Delete failed"); } - let flusher = writer.commit::<&str, Int32Array>().unwrap(); + let flusher = writer.commit::<&str, Vec>().unwrap(); let id_3 = flusher.id(); - flusher.flush::<&str, Int32Array>().await.unwrap(); + flusher.flush::<&str, Vec>().await.unwrap(); let reader = blockfile_provider - .open::<&str, Int32Array>(&id_3) + .open::<&str, &[u32]>(&id_3) .await .unwrap(); for i in 0..n * 2 { let key = format!("{:04}", i); let value = reader.get("key", &key).await.unwrap(); - assert_eq!(value.values(), &[i]); + assert_eq!(value, &[i]); } } diff --git a/rust/blockstore/src/arrow/provider.rs b/rust/blockstore/src/arrow/provider.rs index 922433ec47b..d0898c8873f 100644 --- a/rust/blockstore/src/arrow/provider.rs +++ b/rust/blockstore/src/arrow/provider.rs @@ -197,10 +197,11 @@ impl BlockManager { pub(super) fn commit( &self, - delta: &BlockDelta, + delta: BlockDelta, ) -> Block { + let delta_id = delta.id; let record_batch = delta.finish::(); - let block = Block::from_record_batch(delta.id, record_batch); + let block = Block::from_record_batch(delta_id, record_batch); block } diff --git a/rust/blockstore/src/arrow/sparse_index.rs b/rust/blockstore/src/arrow/sparse_index.rs index ea8448efb85..1cdab29fd7d 100644 --- a/rust/blockstore/src/arrow/sparse_index.rs +++ b/rust/blockstore/src/arrow/sparse_index.rs @@ -503,8 +503,9 @@ impl SparseIndex { } } + let delta_id = delta.id; let record_batch = delta.finish::(); - Ok(Block::from_record_batch(delta.id, record_batch)) + Ok(Block::from_record_batch(delta_id, record_batch)) } pub(super) fn from_block<'block, K: ArrowReadableKey<'block> + 'block>( diff --git a/rust/blockstore/src/lib.rs b/rust/blockstore/src/lib.rs index 7bb7abd12b9..69753ed87bb 100644 --- a/rust/blockstore/src/lib.rs +++ b/rust/blockstore/src/lib.rs @@ -1,4 +1,3 @@ -pub mod positional_posting_list_value; pub mod types; pub mod arrow; diff --git a/rust/blockstore/src/memory/storage.rs b/rust/blockstore/src/memory/storage.rs index 6d188a8c555..5a1dc68ac75 100644 --- a/rust/blockstore/src/memory/storage.rs +++ b/rust/blockstore/src/memory/storage.rs @@ -195,10 +195,10 @@ impl<'referred_data> Readable<'referred_data> for &'referred_data str { } // TODO: remove this and make this all use a unified storage so we don't have two impls -impl Writeable for Int32Array { +impl Writeable for Vec { fn write_to_storage(prefix: &str, key: KeyWrapper, value: Self, storage: &StorageBuilder) { storage - .int32_array_storage + .uint32_array_storage .write() .as_mut() .unwrap() @@ -213,7 +213,7 @@ impl Writeable for Int32Array { fn remove_from_storage(prefix: &str, key: KeyWrapper, storage: &StorageBuilder) { storage - .int32_array_storage + .uint32_array_storage .write() .as_mut() .unwrap() @@ -224,15 +224,19 @@ impl Writeable for Int32Array { } } -impl<'referred_data> Readable<'referred_data> for Int32Array { - fn read_from_storage(prefix: &str, key: KeyWrapper, storage: &Storage) -> Option { +impl<'referred_data> Readable<'referred_data> for &'referred_data [u32] { + fn read_from_storage( + prefix: &str, + key: KeyWrapper, + storage: &'referred_data Storage, + ) -> Option { storage - .int32_array_storage + .uint32_array_storage .get(&CompositeKey { prefix: prefix.to_string(), key, }) - .map(|a| a.clone()) + .map(|a| a.as_slice()) } fn get_by_prefix_from_storage( @@ -240,10 +244,10 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { storage: &'referred_data Storage, ) -> Vec<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .filter(|(k, _)| k.prefix == prefix) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) .collect() } @@ -253,10 +257,10 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { storage: &'referred_data Storage, ) -> Vec<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .filter(|(k, _)| k.prefix == prefix && k.key > key) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) .collect() } @@ -266,10 +270,10 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { storage: &'referred_data Storage, ) -> Vec<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .filter(|(k, _)| k.prefix == prefix && k.key >= key) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) .collect() } @@ -279,10 +283,10 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { storage: &'referred_data Storage, ) -> Vec<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .filter(|(k, _)| k.prefix == prefix && k.key < key) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) .collect() } @@ -292,10 +296,10 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { storage: &'referred_data Storage, ) -> Vec<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .filter(|(k, _)| k.prefix == prefix && k.key <= key) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) .collect() } @@ -304,19 +308,19 @@ impl<'referred_data> Readable<'referred_data> for Int32Array { index: usize, ) -> Option<(&'referred_data CompositeKey, Self)> { storage - .int32_array_storage + .uint32_array_storage .iter() .nth(index) - .map(|(k, v)| (k, v.clone())) + .map(|(k, v)| (k, v.as_slice())) } fn count(storage: &Storage) -> Result> { - Ok(storage.int32_array_storage.iter().len()) + Ok(storage.uint32_array_storage.iter().len()) } fn contains(prefix: &str, key: KeyWrapper, storage: &'referred_data Storage) -> bool { storage - .int32_array_storage + .uint32_array_storage .get(&CompositeKey { prefix: prefix.to_string(), key, @@ -1063,8 +1067,8 @@ pub struct StorageBuilder { f32_storage: Arc>>>, // Roaring Bitmap Value roaring_bitmap_storage: Arc>>>, - // Int32 Array Value - int32_array_storage: Arc>>>, + // UInt32 Array Value + uint32_array_storage: Arc>>>>, // Data Record Fields data_record_id_storage: Arc>>>, data_record_embedding_storage: Arc>>>>, @@ -1082,8 +1086,8 @@ pub struct Storage { f32_storage: Arc>, // Roaring Bitmap Value roaring_bitmap_storage: Arc>, - // Int32 Array Value - int32_array_storage: Arc>, + // UInt32 Array Value + uint32_array_storage: Arc>>, // Data Record Fields data_record_id_storage: Arc>, data_record_embedding_storage: Arc>>, @@ -1118,7 +1122,7 @@ impl StorageManager { u32_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), f32_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), roaring_bitmap_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), - int32_array_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), + uint32_array_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), data_record_id_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), data_record_embedding_storage: Arc::new(RwLock::new(Some(BTreeMap::new()))), id, @@ -1134,7 +1138,7 @@ impl StorageManager { let storage = Storage { bool_storage: builder.bool_storage.write().take().unwrap().into(), string_value_storage: builder.string_value_storage.write().take().unwrap().into(), - int32_array_storage: builder.int32_array_storage.write().take().unwrap().into(), + uint32_array_storage: builder.uint32_array_storage.write().take().unwrap().into(), roaring_bitmap_storage: builder .roaring_bitmap_storage .write() diff --git a/rust/blockstore/src/positional_posting_list_value.rs b/rust/blockstore/src/positional_posting_list_value.rs deleted file mode 100644 index 796f996bc13..00000000000 --- a/rust/blockstore/src/positional_posting_list_value.rs +++ /dev/null @@ -1,263 +0,0 @@ -use arrow::{ - array::{Array, AsArray, Int32Array, Int32Builder, ListArray, ListBuilder}, - datatypes::Int32Type, -}; -use chroma_error::{ChromaError, ErrorCodes}; -use std::collections::{HashMap, HashSet}; -use thiserror::Error; - -#[derive(Debug, Clone)] -pub struct PositionalPostingList { - pub doc_ids: Int32Array, - pub positions: ListArray, -} - -impl PositionalPostingList { - pub fn get_doc_ids(&self) -> Int32Array { - return self.doc_ids.clone(); - } - - pub fn get_positions_for_doc_id(&self, doc_id: i32) -> Option { - let index = self.doc_ids.values().binary_search(&doc_id).ok(); - match index { - Some(index) => { - let target_positions = self.positions.value(index); - // Int32Array is composed of a Datatype, ScalarBuffer, and a null bitmap, these are all cheap to clone since the buffer is Arc'ed - let downcast = target_positions.as_primitive::().clone(); - return Some(downcast); - } - None => None, - } - } - - pub fn size_in_bytes(&self) -> usize { - let mut size = 0; - size += self.doc_ids.len() * std::mem::size_of::(); - size += self.positions.len() * std::mem::size_of::(); - size - } -} - -#[derive(Error, Debug)] -pub enum PositionalPostingListBuilderError { - #[error("Doc ID already exists in the list")] - DocIdAlreadyExists, - #[error("Doc ID does not exist in the list")] - DocIdDoesNotExist, - #[error("Incremental positions must be sorted")] - UnsortedPosition, -} - -impl ChromaError for PositionalPostingListBuilderError { - fn code(&self) -> ErrorCodes { - match self { - PositionalPostingListBuilderError::DocIdAlreadyExists => ErrorCodes::AlreadyExists, - PositionalPostingListBuilderError::DocIdDoesNotExist => ErrorCodes::InvalidArgument, - PositionalPostingListBuilderError::UnsortedPosition => ErrorCodes::InvalidArgument, - } - } -} - -#[derive(Debug)] -pub struct PositionalPostingListBuilder { - doc_ids: HashSet, - positions: HashMap>, -} - -impl PositionalPostingListBuilder { - pub fn new() -> Self { - PositionalPostingListBuilder { - doc_ids: HashSet::new(), - positions: HashMap::new(), - } - } - - pub fn add_doc_id_and_positions( - &mut self, - doc_id: i32, - positions: Vec, - ) -> Result<(), PositionalPostingListBuilderError> { - if self.doc_ids.contains(&doc_id) { - return Err(PositionalPostingListBuilderError::DocIdAlreadyExists); - } - - self.doc_ids.insert(doc_id); - self.positions.insert(doc_id, positions); - Ok(()) - } - - pub fn delete_doc_id(&mut self, doc_id: i32) -> Result<(), PositionalPostingListBuilderError> { - self.doc_ids.remove(&doc_id); - self.positions.remove(&doc_id); - Ok(()) - } - - pub fn contains_doc_id(&self, doc_id: i32) -> bool { - self.doc_ids.contains(&doc_id) - } - - pub fn add_positions_for_doc_id( - &mut self, - doc_id: i32, - positions: Vec, - ) -> Result<(), PositionalPostingListBuilderError> { - if !self.doc_ids.contains(&doc_id) { - return Err(PositionalPostingListBuilderError::DocIdDoesNotExist); - } - - // Safe to unwrap here since this is called for >= 2nd time a token - // exists in the document. - self.positions.get_mut(&doc_id).unwrap().extend(positions); - Ok(()) - } - - pub fn build(&mut self) -> PositionalPostingList { - let mut doc_ids_builder = Int32Builder::new(); - let mut positions_builder = ListBuilder::new(Int32Builder::new()); - - let mut doc_ids_vec: Vec = self.doc_ids.drain().collect(); - doc_ids_vec.sort(); - let doc_ids_slice = doc_ids_vec.as_slice(); - doc_ids_builder.append_slice(doc_ids_slice); - let doc_ids = doc_ids_builder.finish(); - - for doc_id in doc_ids_slice.iter() { - // Get positions for the doc ID, sort them, put them into the positions_builder - let mut positions = self.positions.remove(doc_id).unwrap(); - positions.sort(); - let positions_as_some: Vec> = positions.into_iter().map(Some).collect(); - positions_builder.append_value(positions_as_some); - } - let positions = positions_builder.finish(); - - PositionalPostingList { - doc_ids: doc_ids, - positions: positions, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_positional_posting_list_single_document() { - let mut builder = PositionalPostingListBuilder::new(); - let _res = builder.add_doc_id_and_positions(1, vec![1, 2, 3]); - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3]) - ); - } - - #[test] - fn test_positional_posting_list_multiple_documents() { - let mut builder = PositionalPostingListBuilder::new(); - let _res = builder.add_doc_id_and_positions(1, vec![1, 2, 3]); - let _res = builder.add_doc_id_and_positions(2, vec![4, 5, 6]); - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!(list.get_doc_ids().values()[1], 2); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3]) - ); - assert_eq!( - list.get_positions_for_doc_id(2).unwrap(), - Int32Array::from(vec![4, 5, 6]) - ); - } - - #[test] - fn test_positional_posting_list_document_ids_sorted_after_build() { - let mut builder = PositionalPostingListBuilder::new(); - let _res = builder.add_doc_id_and_positions(2, vec![4, 5, 6]); - let _res = builder.add_doc_id_and_positions(1, vec![1, 2, 3]); - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!(list.get_doc_ids().values()[1], 2); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3]) - ); - assert_eq!( - list.get_positions_for_doc_id(2).unwrap(), - Int32Array::from(vec![4, 5, 6]) - ); - } - - #[test] - fn test_positional_posting_list_all_positions_sorted_after_build() { - let mut builder = PositionalPostingListBuilder::new(); - let _res = builder.add_doc_id_and_positions(1, vec![3, 2, 1]); - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3]) - ); - } - - #[test] - fn test_positional_posting_list_incremental_build() { - let mut builder = PositionalPostingListBuilder::new(); - - let _res = builder.add_doc_id_and_positions(1, vec![1, 2, 3]); - let _res = builder.add_positions_for_doc_id(1, [4].into()); - let _res = builder.add_positions_for_doc_id(1, [5].into()); - let _res = builder.add_positions_for_doc_id(1, [6].into()); - let _res = builder.add_doc_id_and_positions(2, vec![4, 5, 6]); - let _res = builder.add_positions_for_doc_id(2, [7].into()); - - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!(list.get_doc_ids().values()[1], 2); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3, 4, 5, 6]) - ); - } - - #[test] - fn test_positional_posting_list_delete_doc_id() { - let mut builder = PositionalPostingListBuilder::new(); - - let _res = builder.add_doc_id_and_positions(1, vec![1, 2, 3]); - let _res = builder.add_doc_id_and_positions(2, vec![4, 5, 6]); - let _res = builder.delete_doc_id(1); - - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 2); - assert_eq!( - list.get_positions_for_doc_id(2).unwrap(), - Int32Array::from(vec![4, 5, 6]) - ); - } - - #[test] - fn test_all_positional_posting_list_behaviors_together() { - let mut builder = PositionalPostingListBuilder::new(); - - let _res = builder.add_doc_id_and_positions(1, vec![3, 2, 1]); - let _res = builder.add_positions_for_doc_id(1, [4].into()); - let _res = builder.add_positions_for_doc_id(1, [6].into()); - let _res = builder.add_positions_for_doc_id(1, [5].into()); - let _res = builder.add_doc_id_and_positions(2, vec![5, 4, 6]); - let _res = builder.add_positions_for_doc_id(2, [7].into()); - - let list = builder.build(); - assert_eq!(list.get_doc_ids().values()[0], 1); - assert_eq!(list.get_doc_ids().values()[1], 2); - assert_eq!( - list.get_positions_for_doc_id(1).unwrap(), - Int32Array::from(vec![1, 2, 3, 4, 5, 6]) - ); - assert_eq!( - list.get_positions_for_doc_id(2).unwrap(), - Int32Array::from(vec![4, 5, 6, 7]) - ); - } -} diff --git a/rust/blockstore/src/types.rs b/rust/blockstore/src/types.rs index 7997d28a57a..2cf924b0a66 100644 --- a/rust/blockstore/src/types.rs +++ b/rust/blockstore/src/types.rs @@ -8,12 +8,11 @@ use super::memory::reader_writer::{ MemoryBlockfileFlusher, MemoryBlockfileReader, MemoryBlockfileWriter, }; use super::memory::storage::{Readable, Writeable}; -use super::positional_posting_list_value::PositionalPostingList; -use arrow::array::{Array, Int32Array}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::DataRecord; use roaring::RoaringBitmap; use std::fmt::{Debug, Display}; +use std::mem::size_of; use thiserror::Error; #[derive(Debug, Error)] @@ -79,17 +78,15 @@ pub trait Value: Clone { fn get_size(&self) -> usize; } -// TODO: Maybe make writeable and readable traits' -// TODO: we don't need this get size -impl Value for Int32Array { +impl Value for Vec { fn get_size(&self) -> usize { - self.get_buffer_memory_size() + self.len() * size_of::() } } -impl Value for &Int32Array { +impl Value for &[u32] { fn get_size(&self) -> usize { - self.get_buffer_memory_size() + self.len() * size_of::() } } @@ -123,12 +120,6 @@ impl Value for &RoaringBitmap { } } -impl Value for PositionalPostingList { - fn get_size(&self) -> usize { - return self.size_in_bytes(); - } -} - impl<'a> Value for DataRecord<'a> { fn get_size(&self) -> usize { DataRecord::get_size(self) diff --git a/rust/index/src/fulltext/types.rs b/rust/index/src/fulltext/types.rs index 69ef7d1364a..af64ec54337 100644 --- a/rust/index/src/fulltext/types.rs +++ b/rust/index/src/fulltext/types.rs @@ -1,10 +1,6 @@ use crate::fulltext::tokenizer::ChromaTokenizer; use crate::metadata::types::MetadataIndexError; use crate::utils::{merge_sorted_vecs_conjunction, merge_sorted_vecs_disjunction}; -use arrow::array::Int32Array; -use chroma_blockstore::positional_posting_list_value::{ - PositionalPostingListBuilder, PositionalPostingListBuilderError, -}; use chroma_blockstore::{BlockfileFlusher, BlockfileReader, BlockfileWriter}; use chroma_error::{ChromaError, ErrorCodes}; use chroma_types::{BooleanOperator, WhereDocument, WhereDocumentOperator}; @@ -22,8 +18,6 @@ pub enum FullTextIndexError { EmptyValueInPositionalPostingList, #[error("Invariant violation")] InvariantViolation, - #[error("Positional posting list error: {0}")] - PositionalPostingListError(#[from] PositionalPostingListBuilderError), #[error("Blockfile write error: {0}")] BlockfileWriteError(#[from] Box), } @@ -37,7 +31,7 @@ impl ChromaError for FullTextIndexError { #[derive(Debug)] pub struct UncommittedPostings { // token -> {doc -> [start positions]} - positional_postings: HashMap, + positional_postings: HashMap>>, // (token, doc) pairs that should be deleted from storage. deleted_token_doc_pairs: HashSet<(String, i32)>, } @@ -123,7 +117,6 @@ impl<'me> FullTextIndexWriter<'me> { return Err(FullTextIndexError::InvariantViolation); } None => { - let mut builder = PositionalPostingListBuilder::new(); let results = match &self.full_text_index_reader { // Readers are uninitialized until the first compaction finishes // so there is a case when this is none hence not an error. @@ -134,18 +127,13 @@ impl<'me> FullTextIndexWriter<'me> { Err(_) => vec![], }, }; - for (doc_id, positions) in results { - let res = builder.add_doc_id_and_positions(doc_id as i32, positions); - match res { - Ok(_) => {} - Err(e) => { - return Err(FullTextIndexError::PositionalPostingListError(e)); - } - } + let mut doc_and_positions = HashMap::new(); + for result in results { + doc_and_positions.insert(result.0, result.1); } uncommitted_postings .positional_postings - .insert(token.to_string(), builder); + .insert(token.to_string(), doc_and_positions); } } Ok(()) @@ -160,7 +148,7 @@ impl<'me> FullTextIndexWriter<'me> { pub async fn add_document( &self, document: &str, - offset_id: i32, + offset_id: u32, ) -> Result<(), FullTextIndexError> { let tokens = self.encode_tokens(document); for token in tokens.get_tokens() { @@ -178,28 +166,22 @@ impl<'me> FullTextIndexWriter<'me> { let builder = uncommitted_postings .positional_postings .entry(token.text.to_string()) - .or_insert(PositionalPostingListBuilder::new()); + .or_insert(HashMap::new()); // Store starting positions of tokens. These are NOT affected by token filters. // For search, we can use the start and end positions to compute offsets to // check full string match. // // See https://docs.rs/tantivy/latest/tantivy/tokenizer/struct.Token.html - if !builder.contains_doc_id(offset_id) { + if !builder.contains_key(&offset_id) { // Casting to i32 is safe since we limit the size of the document. - match builder.add_doc_id_and_positions(offset_id, vec![token.offset_from as i32]) { - Ok(_) => {} - Err(e) => { - return Err(FullTextIndexError::PositionalPostingListError(e)); - } - } + builder.insert(offset_id, vec![token.offset_from as u32]); } else { - match builder.add_positions_for_doc_id(offset_id, vec![token.offset_from as i32]) { - Ok(_) => {} - Err(e) => { - return Err(FullTextIndexError::PositionalPostingListError(e)); - } - } + // unwrap() is safe since we already verified that the key exists. + builder + .get_mut(&offset_id) + .unwrap() + .push(token.offset_from as u32); } } Ok(()) @@ -230,30 +212,23 @@ impl<'me> FullTextIndexWriter<'me> { .positional_postings .get_mut(token.text.as_str()) { - Some(builder) => match builder.delete_doc_id(offset_id as i32) { - Ok(_) => { - // Track all the deleted (token, doc) pairs. This is needed - // to remove the old postings list for this pair from storage. + Some(builder) => { + builder.remove(&offset_id); + if builder.is_empty() { uncommitted_postings - .deleted_token_doc_pairs - .insert((token.text.clone(), offset_id as i32)); - } - Err(e) => { - // This is a fatal invariant violation: we've been asked to - // delete a document which doesn't appear in the positional posting list. - // It probably indicates data corruption of some sort. - tracing::error!( - "Error deleting doc ID from positional posting list: {:?}", - e - ); - return Err(FullTextIndexError::PositionalPostingListError(e)); + .positional_postings + .remove(token.text.as_str()); } - }, - None => { - // Invariant violation -- we just populated this. - tracing::error!("Error deleting doc ID from positional posting list"); - return Err(FullTextIndexError::InvariantViolation); + // Track all the deleted (token, doc) pairs. This is needed + // to remove the old postings list for this pair from storage. + uncommitted_postings + .deleted_token_doc_pairs + .insert((token.text.clone(), offset_id as i32)); } + // This is fine since we delete all the positions of a token + // of a document at once so the next time we encounter this token + // (at a different position) the map could be empty. + None => {} } } Ok(()) @@ -266,7 +241,7 @@ impl<'me> FullTextIndexWriter<'me> { offset_id: u32, ) -> Result<(), FullTextIndexError> { self.delete_document(old_document, offset_id).await?; - self.add_document(new_document, offset_id as i32).await?; + self.add_document(new_document, offset_id).await?; Ok(()) } @@ -278,7 +253,7 @@ impl<'me> FullTextIndexWriter<'me> { for (token, offset_id) in uncommitted_postings.deleted_token_doc_pairs.drain() { match self .posting_lists_blockfile_writer - .delete::(token.as_str(), offset_id as u32) + .delete::>(token.as_str(), offset_id as u32) .await { Ok(_) => {} @@ -289,30 +264,20 @@ impl<'me> FullTextIndexWriter<'me> { } for (key, mut value) in uncommitted_postings.positional_postings.drain() { - let built_list = value.build(); - for doc_id in built_list.doc_ids.iter() { - match doc_id { - Some(doc_id) => { - let positional_posting_list = - built_list.get_positions_for_doc_id(doc_id).unwrap(); - // Don't add if postings list is empty for this (token, doc) combo. - // This can happen with deletes. - if positional_posting_list.len() > 0 { - match self - .posting_lists_blockfile_writer - .set(key.as_str(), doc_id as u32, positional_posting_list) - .await - { - Ok(_) => {} - Err(e) => { - return Err(FullTextIndexError::BlockfileWriteError(e)); - } - } + for (doc_id, positions) in value.drain() { + // Don't add if postings list is empty for this (token, doc) combo. + // This can happen with deletes. + if positions.len() > 0 { + match self + .posting_lists_blockfile_writer + .set(key.as_str(), doc_id, positions) + .await + { + Ok(_) => {} + Err(e) => { + return Err(FullTextIndexError::BlockfileWriteError(e)); } } - None => { - panic!("Positions for doc ID not found in positional posting list -- should never happen") - } } } } @@ -356,7 +321,7 @@ impl<'me> FullTextIndexWriter<'me> { // TODO should we be `await?`ing these? Or can we just return the futures? let posting_lists_blockfile_flusher = self .posting_lists_blockfile_writer - .commit::()?; + .commit::>()?; let frequencies_blockfile_flusher = self.frequencies_blockfile_writer.commit::()?; Ok(FullTextIndexFlusher { @@ -375,7 +340,7 @@ impl FullTextIndexFlusher { pub async fn flush(self) -> Result<(), FullTextIndexError> { match self .posting_lists_blockfile_flusher - .flush::() + .flush::>() .await { Ok(_) => {} @@ -407,14 +372,14 @@ impl FullTextIndexFlusher { #[derive(Clone)] pub struct FullTextIndexReader<'me> { - posting_lists_blockfile_reader: BlockfileReader<'me, u32, Int32Array>, + posting_lists_blockfile_reader: BlockfileReader<'me, u32, &'me [u32]>, frequencies_blockfile_reader: BlockfileReader<'me, u32, u32>, tokenizer: Arc>>, } impl<'me> FullTextIndexReader<'me> { pub fn new( - posting_lists_blockfile_reader: BlockfileReader<'me, u32, Int32Array>, + posting_lists_blockfile_reader: BlockfileReader<'me, u32, &'me [u32]>, frequencies_blockfile_reader: BlockfileReader<'me, u32, u32>, tokenizer: Box, ) -> Self { @@ -469,7 +434,7 @@ impl<'me> FullTextIndexReader<'me> { // Populate initial candidates with the least-frequent token's posting list. // doc ID -> possible starting locations for the query. - let mut candidates: HashMap> = HashMap::new(); + let mut candidates: HashMap> = HashMap::new(); let first_token = token_frequencies[0].0.as_str(); let first_token_positional_posting_list = self .posting_lists_blockfile_reader @@ -477,8 +442,7 @@ impl<'me> FullTextIndexReader<'me> { .await .unwrap(); for (_, doc_id, positions) in first_token_positional_posting_list.iter() { - let positions_vec: Vec = positions.iter().map(|x| x.unwrap()).collect(); - candidates.insert(*doc_id, positions_vec); + candidates.insert(*doc_id, positions.to_vec()); } // Iterate through the rest of the tokens, intersecting the posting lists with the candidates. @@ -498,7 +462,7 @@ impl<'me> FullTextIndexReader<'me> { // .find(|t| t.text == *token) // .unwrap() // .offset_from as i32; - let mut new_candidates: HashMap> = HashMap::new(); + let mut new_candidates: HashMap> = HashMap::new(); for (doc_id, positions) in candidates.iter() { let mut new_positions = vec![]; for position in positions { @@ -510,19 +474,8 @@ impl<'me> FullTextIndexReader<'me> { .map(|x| &x.2) { for pos in positions.iter() { - match pos { - None => { - // This should never happen since we only store positions for the doc_id - // in the positional posting list. - return Err( - FullTextIndexError::EmptyValueInPositionalPostingList, - ); - } - Some(pos) => { - if pos == position + token_offset { - new_positions.push(*position); - } - } + if *pos == position + token_offset { + new_positions.push(*position); } } } @@ -550,15 +503,14 @@ impl<'me> FullTextIndexReader<'me> { async fn get_all_results_for_token( &self, token: &str, - ) -> Result)>, FullTextIndexError> { + ) -> Result)>, FullTextIndexError> { let positional_posting_list = self .posting_lists_blockfile_reader .get_by_prefix(token) .await?; let mut results = vec![]; for (_, doc_id, positions) in positional_posting_list.iter() { - let positions_vec: Vec = positions.iter().map(|x| x.unwrap()).collect(); - results.push((*doc_id, positions_vec)); + results.push((*doc_id, positions.to_vec())); } Ok(results) } @@ -640,7 +592,7 @@ mod tests { #[test] fn test_new_writer() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( NgramTokenizer::new(1, 1, false).unwrap(), @@ -653,7 +605,7 @@ mod tests { async fn test_new_writer_then_reader() { let provider = BlockfileProvider::new_memory(); let freq_blockfile_writer = provider.create::().unwrap(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_id = freq_blockfile_writer.id(); let pl_blockfile_id = pl_blockfile_writer.id(); @@ -668,7 +620,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -680,7 +632,7 @@ mod tests { #[tokio::test] async fn test_index_and_search_single_document() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -697,7 +649,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -719,7 +671,7 @@ mod tests { #[tokio::test] async fn test_repeating_character_in_query() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -736,7 +688,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -752,7 +704,7 @@ mod tests { #[tokio::test] async fn test_query_of_repeating_character() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -770,7 +722,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -786,7 +738,7 @@ mod tests { #[tokio::test] async fn test_repeating_character_in_document() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -803,7 +755,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -819,7 +771,7 @@ mod tests { #[tokio::test] async fn test_search_absent_token() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -836,7 +788,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -852,7 +804,7 @@ mod tests { #[tokio::test] async fn test_multiple_candidates_within_document() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -873,7 +825,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -893,7 +845,7 @@ mod tests { #[tokio::test] async fn test_multiple_simple_documents() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -911,7 +863,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -931,7 +883,7 @@ mod tests { #[tokio::test] async fn test_multiple_complex_documents() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -951,7 +903,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -980,7 +932,7 @@ mod tests { #[tokio::test] async fn test_index_multiple_character_repeating() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1004,7 +956,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -1029,7 +981,7 @@ mod tests { #[tokio::test] async fn test_index_special_characters() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1051,7 +1003,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -1074,7 +1026,7 @@ mod tests { #[tokio::test] async fn test_get_frequencies_for_token() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1095,7 +1047,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -1117,7 +1069,7 @@ mod tests { #[tokio::test] async fn test_get_all_results_for_token() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1138,7 +1090,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -1160,7 +1112,7 @@ mod tests { #[tokio::test] async fn test_update_document() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1185,7 +1137,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( @@ -1205,7 +1157,7 @@ mod tests { #[tokio::test] async fn test_delete_document() { let provider = BlockfileProvider::new_memory(); - let pl_blockfile_writer = provider.create::().unwrap(); + let pl_blockfile_writer = provider.create::>().unwrap(); let freq_blockfile_writer = provider.create::().unwrap(); let pl_blockfile_id = pl_blockfile_writer.id(); let freq_blockfile_id = freq_blockfile_writer.id(); @@ -1227,7 +1179,7 @@ mod tests { let freq_blockfile_reader = provider.open::(&freq_blockfile_id).await.unwrap(); let pl_blockfile_reader = provider - .open::(&pl_blockfile_id) + .open::(&pl_blockfile_id) .await .unwrap(); let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new( diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs index c99cf2688e3..a00a0c6561a 100644 --- a/rust/worker/src/segment/metadata_segment.rs +++ b/rust/worker/src/segment/metadata_segment.rs @@ -125,21 +125,20 @@ impl<'me> MetadataSegmentWriter<'me> { return Err(MetadataSegmentError::UuidParseError(pls_uuid.to_string())) } }; - let pls_writer = - match blockfile_provider.fork::(&pls_uuid).await { - Ok(writer) => writer, - Err(e) => return Err(MetadataSegmentError::BlockfileError(*e)), - }; - let pls_reader = - match blockfile_provider.open::(&pls_uuid).await { - Ok(reader) => reader, - Err(e) => return Err(MetadataSegmentError::BlockfileOpenError(*e)), - }; + let pls_writer = match blockfile_provider.fork::>(&pls_uuid).await + { + Ok(writer) => writer, + Err(e) => return Err(MetadataSegmentError::BlockfileError(*e)), + }; + let pls_reader = match blockfile_provider.open::(&pls_uuid).await { + Ok(reader) => reader, + Err(e) => return Err(MetadataSegmentError::BlockfileOpenError(*e)), + }; (pls_writer, Some(pls_reader)) } None => return Err(MetadataSegmentError::EmptyPathVector), }, - None => match blockfile_provider.create::() { + None => match blockfile_provider.create::>() { Ok(writer) => (writer, None), Err(e) => return Err(MetadataSegmentError::BlockfileError(*e)), }, @@ -594,7 +593,7 @@ impl<'log_records> SegmentWriter<'log_records> for MetadataSegmentWriter<'_> { Some(document) => match &self.full_text_index_writer { Some(writer) => { let _ = writer - .add_document(document, segment_offset_id as i32) + .add_document(document, segment_offset_id) .await; } None => panic!( @@ -712,7 +711,7 @@ impl<'log_records> SegmentWriter<'log_records> for MetadataSegmentWriter<'_> { } // Previous version of record does not contain document string. None => match writer - .add_document(doc, segment_offset_id as i32) + .add_document(doc, segment_offset_id) .await { Ok(_) => {} @@ -800,7 +799,7 @@ impl<'log_records> SegmentWriter<'log_records> for MetadataSegmentWriter<'_> { Some(document) => match &self.full_text_index_writer { Some(writer) => { let _ = writer - .add_document(document, segment_offset_id as i32) + .add_document(document, segment_offset_id) .await; } None => panic!( @@ -984,11 +983,10 @@ impl MetadataSegmentReader<'_> { return Err(MetadataSegmentError::UuidParseError(pls_uuid.to_string())) } }; - let pls_reader = - match blockfile_provider.open::(&pls_uuid).await { - Ok(reader) => Some(reader), - Err(e) => return Err(MetadataSegmentError::BlockfileOpenError(*e)), - }; + let pls_reader = match blockfile_provider.open::(&pls_uuid).await { + Ok(reader) => Some(reader), + Err(e) => return Err(MetadataSegmentError::BlockfileOpenError(*e)), + }; pls_reader } None => None,