From 537f29aa3e402c68a932929335591dd3793847a7 Mon Sep 17 00:00:00 2001 From: Drew Kim Date: Thu, 16 Jan 2025 12:01:28 -0800 Subject: [PATCH] [ENH] Update compactor to flush total records upon compaction (#3483) ## Description of changes Updates the compactor to calculate the total records per collection and flush to the sysdb upon every compaction. *Summarize the changes made by this PR.* - New functionality - `FlushCollectionCompaction` struct includes `TotalRecordsPostCompaction` - `SysDB` populates the `total_records_post_compaction` column when receiving a flush - `ArrowBlockfileFlusher` contains a new attribute, `total_keys` - `ArrowUnorderedBlockfileWriter` sums the total keys using the `SparseIndexWriter` and returns an `ArrowBlockfileFlusher` with the summed count - `RegisterInput` contains a new attribute, `total_records_post_compaction ` - If `CompactOrchestrator`, when handling `CommitSegmentWriterOutput`, receives a `ChromaSegmentFlusher::RecordSegment`, it reads `total_keys()` and sets it as an attribute on itself. - `ChromaSegmentFlusher::RecordSegment` has `total_keys()` through its `ArrowBlockfileFlusher` - `CompactOrchestrator` sends its `num_records_last_compaction` value to a `RegisterInput` to be flushed to the `SysDB` ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust - [x] Tested locally and confirmed that compaction correctly updates the column in the `SysDB` ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? 
Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- go/pkg/sysdb/coordinator/model/collection.go | 11 +-- go/pkg/sysdb/coordinator/table_catalog.go | 2 +- go/pkg/sysdb/grpc/collection_service.go | 11 +-- go/pkg/sysdb/metastore/db/dao/collection.go | 6 +- .../sysdb/metastore/db/dao/collection_test.go | 11 +-- .../sysdb/metastore/db/dbmodel/collection.go | 2 +- .../db/dbmodel/mocks/ICollectionDb.go | 20 ++--- .../metastore/db/dbmodel/mocks/ISegmentDb.go | 30 +++++++ idl/chromadb/proto/coordinator.proto | 1 + rust/blockstore/src/arrow/blockfile.rs | 86 ++++++++++++++++++- rust/blockstore/src/arrow/flusher.rs | 7 ++ .../src/arrow/ordered_blockfile_writer.rs | 86 ++++++++++++++++++- rust/blockstore/src/types/flusher.rs | 7 ++ rust/sysdb/src/sysdb.rs | 5 ++ rust/sysdb/src/test_sysdb.rs | 2 + .../src/execution/operators/register.rs | 16 +++- .../src/execution/orchestration/compact.rs | 12 ++- rust/worker/src/segment/metadata_segment.rs | 2 + rust/worker/src/segment/record_segment.rs | 4 + 19 files changed, 286 insertions(+), 35 deletions(-) diff --git a/go/pkg/sysdb/coordinator/model/collection.go b/go/pkg/sysdb/coordinator/model/collection.go index 45bc882a98b..1263bfd0dee 100644 --- a/go/pkg/sysdb/coordinator/model/collection.go +++ b/go/pkg/sysdb/coordinator/model/collection.go @@ -50,11 +50,12 @@ type UpdateCollection struct { } type FlushCollectionCompaction struct { - ID types.UniqueID - TenantID string - LogPosition int64 - CurrentCollectionVersion int32 - FlushSegmentCompactions []*FlushSegmentCompaction + ID types.UniqueID + TenantID string + LogPosition int64 + CurrentCollectionVersion int32 + FlushSegmentCompactions []*FlushSegmentCompaction + TotalRecordsPostCompaction uint64 } type FlushCollectionInfo struct { diff --git a/go/pkg/sysdb/coordinator/table_catalog.go b/go/pkg/sysdb/coordinator/table_catalog.go index afd8796f86c..c6f5cfe4a88 100644 --- a/go/pkg/sysdb/coordinator/table_catalog.go +++ 
b/go/pkg/sysdb/coordinator/table_catalog.go @@ -864,7 +864,7 @@ func (tc *Catalog) FlushCollectionCompaction(ctx context.Context, flushCollectio } // update collection log position and version - collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionAndVersion(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion) + collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionVersionAndTotalRecords(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion, flushCollectionCompaction.TotalRecordsPostCompaction) if err != nil { return err } diff --git a/go/pkg/sysdb/grpc/collection_service.go b/go/pkg/sysdb/grpc/collection_service.go index 7d7abd69647..055ee972857 100644 --- a/go/pkg/sysdb/grpc/collection_service.go +++ b/go/pkg/sysdb/grpc/collection_service.go @@ -325,11 +325,12 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator }) } FlushCollectionCompaction := &model.FlushCollectionCompaction{ - ID: collectionID, - TenantID: req.TenantId, - LogPosition: req.LogPosition, - CurrentCollectionVersion: req.CollectionVersion, - FlushSegmentCompactions: segmentCompactionInfo, + ID: collectionID, + TenantID: req.TenantId, + LogPosition: req.LogPosition, + CurrentCollectionVersion: req.CollectionVersion, + FlushSegmentCompactions: segmentCompactionInfo, + TotalRecordsPostCompaction: req.TotalRecordsPostCompaction, } flushCollectionInfo, err := s.coordinator.FlushCollectionCompaction(ctx, FlushCollectionCompaction) if err != nil { diff --git a/go/pkg/sysdb/metastore/db/dao/collection.go b/go/pkg/sysdb/metastore/db/dao/collection.go index 0c7483a0bfa..e911f6a6b31 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection.go +++ b/go/pkg/sysdb/metastore/db/dao/collection.go @@ -211,8 +211,8 @@ func (s *collectionDb) Update(in *dbmodel.Collection) error { return nil } 
-func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) { - log.Info("update log position and version", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion)) +func (s *collectionDb) UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) { + log.Info("update log position, version, and total records post compaction", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion), zap.Uint64("totalRecords", totalRecordsPostCompaction)) var collection dbmodel.Collection // We use select for update to ensure no lost update happens even for isolation level read committed or below // https://patrick.engineering/posts/postgres-internals/ @@ -232,7 +232,7 @@ func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosit } version := currentCollectionVersion + 1 - err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version}).Error + err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version, "total_records_post_compaction": totalRecordsPostCompaction}).Error if err != nil { return 0, err } diff --git a/go/pkg/sysdb/metastore/db/dao/collection_test.go b/go/pkg/sysdb/metastore/db/dao/collection_test.go index 63df8dc3f54..a2cdae6c0de 100644 --- a/go/pkg/sysdb/metastore/db/dao/collection_test.go +++ b/go/pkg/sysdb/metastore/db/dao/collection_test.go @@ -121,7 +121,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() { suite.NoError(err) } -func (suite *CollectionDbTestSuite) 
TestCollectionDb_UpdateLogPositionAndVersion() { +func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionVersionAndTotalRecords() { collectionName := "test_collection_get_collections" collectionID, _ := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId) // verify default values @@ -132,22 +132,23 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionAndVersion suite.Equal(int32(0), collections[0].Collection.Version) // update log position and version - version, err := suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(10), 0) + version, err := suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(10), 0, uint64(100)) suite.NoError(err) suite.Equal(int32(1), version) collections, _ = suite.collectionDb.GetCollections(&collectionID, nil, "", "", nil, nil) suite.Len(collections, 1) suite.Equal(int64(10), collections[0].Collection.LogPosition) suite.Equal(int32(1), collections[0].Collection.Version) + suite.Equal(uint64(100), collections[0].Collection.TotalRecordsPostCompaction) // invalid log position - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(5), 0) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(5), 0, uint64(100)) suite.Error(err, "collection log position Stale") // invalid version - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 0) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 0, uint64(100)) suite.Error(err, "collection version invalid") - _, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 3) + _, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 3, uint64(100)) suite.Error(err, "collection version invalid") //clean up diff --git a/go/pkg/sysdb/metastore/db/dbmodel/collection.go b/go/pkg/sysdb/metastore/db/dbmodel/collection.go index 
7d132bdb121..4008d23226e 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/collection.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/collection.go @@ -40,6 +40,6 @@ type ICollectionDb interface { Insert(in *Collection) error Update(in *Collection) error DeleteAll() error - UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) + UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) GetCollectionEntry(collectionID *string, databaseName *string) (*Collection, error) } diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go index 21a3596f8ba..90f8bd6c33b 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go @@ -184,27 +184,27 @@ func (_m *ICollectionDb) Update(in *dbmodel.Collection) error { return r0 } -// UpdateLogPositionAndVersion provides a mock function with given fields: collectionID, logPosition, currentCollectionVersion -func (_m *ICollectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) { - ret := _m.Called(collectionID, logPosition, currentCollectionVersion) +// UpdateLogPositionVersionAndTotalRecords provides a mock function with given fields: collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction +func (_m *ICollectionDb) UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) { + ret := _m.Called(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) if len(ret) == 0 { - panic("no return value specified for UpdateLogPositionAndVersion") + panic("no return value specified for UpdateLogPositionVersionAndTotalRecords") } var 
r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(string, int64, int32) (int32, error)); ok { - return rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(0).(func(string, int64, int32, uint64) (int32, error)); ok { + return rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } - if rf, ok := ret.Get(0).(func(string, int64, int32) int32); ok { - r0 = rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(0).(func(string, int64, int32, uint64) int32); ok { + r0 = rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(string, int64, int32) error); ok { - r1 = rf(collectionID, logPosition, currentCollectionVersion) + if rf, ok := ret.Get(1).(func(string, int64, int32, uint64) error); ok { + r1 = rf(collectionID, logPosition, currentCollectionVersion, totalRecordsPostCompaction) } else { r1 = ret.Error(1) } diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go index 08087c86c37..040bf1de60c 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go @@ -82,6 +82,36 @@ func (_m *ISegmentDb) GetSegments(id types.UniqueID, segmentType *string, scope return r0, r1 } +// GetSegmentsByCollectionID provides a mock function with given fields: collectionID +func (_m *ISegmentDb) GetSegmentsByCollectionID(collectionID string) ([]*dbmodel.Segment, error) { + ret := _m.Called(collectionID) + + if len(ret) == 0 { + panic("no return value specified for GetSegmentsByCollectionID") + } + + var r0 []*dbmodel.Segment + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]*dbmodel.Segment, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(string) []*dbmodel.Segment); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = 
ret.Get(0).([]*dbmodel.Segment) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Insert provides a mock function with given fields: _a0 func (_m *ISegmentDb) Insert(_a0 *dbmodel.Segment) error { ret := _m.Called(_a0) diff --git a/idl/chromadb/proto/coordinator.proto b/idl/chromadb/proto/coordinator.proto index 30c6613dff7..df7fb26629b 100644 --- a/idl/chromadb/proto/coordinator.proto +++ b/idl/chromadb/proto/coordinator.proto @@ -225,6 +225,7 @@ message FlushCollectionCompactionRequest { int64 log_position = 3; int32 collection_version = 4; repeated FlushSegmentCompactionInfo segment_compaction_info = 5; + uint64 total_records_post_compaction = 6; } message FlushCollectionCompactionResponse { diff --git a/rust/blockstore/src/arrow/blockfile.rs b/rust/blockstore/src/arrow/blockfile.rs index 6be8b5a2fe3..3c831d9f92d 100644 --- a/rust/blockstore/src/arrow/blockfile.rs +++ b/rust/blockstore/src/arrow/blockfile.rs @@ -133,12 +133,23 @@ impl ArrowUnorderedBlockfileWriter { Box::new(ArrowBlockfileError::MigrationError(e)) as Box })?; + let count = self + .root + .sparse_index + .data + .lock() + .counts + .values() + .map(|&x| x as u64) + .sum::(); + let flusher = ArrowBlockfileFlusher::new( self.block_manager, self.root_manager, blocks, self.root, self.id, + count, ); Ok(flusher) @@ -705,7 +716,7 @@ mod tests { use uuid::Uuid; #[tokio::test] - async fn test_count() { + async fn test_reader_count() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -744,6 +755,79 @@ mod tests { } } + #[tokio::test] + async fn test_writer_count() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let 
blockfile_provider = ArrowBlockfileProvider::new( + storage, + TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + + // Test no keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + flusher.flush::<&str, Vec>().await.unwrap(); + + // Test 2 keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_1 = "key"; + let key1 = "zzzz"; + let value1 = vec![1, 2, 3]; + writer.set(prefix_1, key1, value1.clone()).await.unwrap(); + + let prefix_2 = "key"; + let key2 = "aaaa"; + let value2 = vec![4, 5, 6]; + writer.set(prefix_2, key2, value2).await.unwrap(); + + let flusher1 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher1.count()); + + // Test add keys after commit, before flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_3 = "key"; + let key3 = "yyyy"; + let value3 = vec![7, 8, 9]; + writer.set(prefix_3, key3, value3.clone()).await.unwrap(); + + let prefix_4 = "key"; + let key4 = "bbbb"; + let value4 = vec![10, 11, 12]; + writer.set(prefix_4, key4, value4).await.unwrap(); + + let flusher2 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher2.count()); + + flusher1.flush::<&str, Vec>().await.unwrap(); + flusher2.flush::<&str, Vec>().await.unwrap(); + + // Test count after flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + } + fn test_prefix(num_keys: u32, prefix_for_query: u32) { Runtime::new().unwrap().block_on(async { let tmp_dir = tempfile::tempdir().unwrap(); diff --git a/rust/blockstore/src/arrow/flusher.rs 
b/rust/blockstore/src/arrow/flusher.rs index 401d3ae4f7d..02726eb9f5c 100644 --- a/rust/blockstore/src/arrow/flusher.rs +++ b/rust/blockstore/src/arrow/flusher.rs @@ -14,6 +14,7 @@ pub struct ArrowBlockfileFlusher { blocks: Vec, root: RootWriter, id: Uuid, + count: u64, } impl ArrowBlockfileFlusher { @@ -23,6 +24,7 @@ impl ArrowBlockfileFlusher { blocks: Vec, root: RootWriter, id: Uuid, + count: u64, ) -> Self { Self { block_manager, @@ -30,6 +32,7 @@ impl ArrowBlockfileFlusher { blocks, root, id, + count, } } @@ -68,4 +71,8 @@ impl ArrowBlockfileFlusher { pub(crate) fn id(&self) -> Uuid { self.id } + + pub(crate) fn count(&self) -> u64 { + self.count + } } diff --git a/rust/blockstore/src/arrow/ordered_blockfile_writer.rs b/rust/blockstore/src/arrow/ordered_blockfile_writer.rs index 5fa021eefc7..6806f379b99 100644 --- a/rust/blockstore/src/arrow/ordered_blockfile_writer.rs +++ b/rust/blockstore/src/arrow/ordered_blockfile_writer.rs @@ -194,12 +194,23 @@ impl ArrowOrderedBlockfileWriter { Box::new(ArrowBlockfileError::MigrationError(e)) as Box })?; + let count = self + .root + .sparse_index + .data + .lock() + .counts + .values() + .map(|&x| x as u64) + .sum::(); + let flusher = ArrowBlockfileFlusher::new( self.block_manager, self.root_manager, blocks, self.root, self.id, + count, ); Ok(flusher) @@ -366,7 +377,7 @@ mod tests { use uuid::Uuid; #[tokio::test] - async fn test_count() { + async fn test_reader_count() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -405,6 +416,79 @@ mod tests { } } + #[tokio::test] + async fn test_writer_count() { + let tmp_dir = tempfile::tempdir().unwrap(); + let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); + let block_cache = new_cache_for_test(); + let sparse_index_cache = new_cache_for_test(); + let blockfile_provider = ArrowBlockfileProvider::new( + storage, + 
TEST_MAX_BLOCK_SIZE_BYTES, + block_cache, + sparse_index_cache, + ); + + // Test no keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + flusher.flush::<&str, Vec>().await.unwrap(); + + // Test 2 keys + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_1 = "key"; + let key1 = "zzzz"; + let value1 = vec![1, 2, 3]; + writer.set(prefix_1, key1, value1.clone()).await.unwrap(); + + let prefix_2 = "key"; + let key2 = "aaaa"; + let value2 = vec![4, 5, 6]; + writer.set(prefix_2, key2, value2).await.unwrap(); + + let flusher1 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher1.count()); + + // Test add keys after commit, before flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + + let prefix_3 = "key"; + let key3 = "yyyy"; + let value3 = vec![7, 8, 9]; + writer.set(prefix_3, key3, value3.clone()).await.unwrap(); + + let prefix_4 = "key"; + let key4 = "bbbb"; + let value4 = vec![10, 11, 12]; + writer.set(prefix_4, key4, value4).await.unwrap(); + + let flusher2 = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(2_u64, flusher2.count()); + + flusher1.flush::<&str, Vec>().await.unwrap(); + flusher2.flush::<&str, Vec>().await.unwrap(); + + // Test count after flush + let writer = blockfile_provider + .write::<&str, Vec>(BlockfileWriterOptions::default()) + .await + .unwrap(); + let flusher = writer.commit::<&str, Vec>().await.unwrap(); + assert_eq!(0_u64, flusher.count()); + } + #[tokio::test] async fn test_blockfile() { let tmp_dir = tempfile::tempdir().unwrap(); diff --git a/rust/blockstore/src/types/flusher.rs b/rust/blockstore/src/types/flusher.rs index 12d9ee76788..8364c382d2d 100644 --- a/rust/blockstore/src/types/flusher.rs +++ 
b/rust/blockstore/src/types/flusher.rs @@ -31,4 +31,11 @@ impl BlockfileFlusher { BlockfileFlusher::ArrowBlockfileFlusher(flusher) => flusher.id(), } } + + pub fn count(&self) -> u64 { + match self { + BlockfileFlusher::MemoryBlockfileFlusher(_) => unimplemented!(), // no op + BlockfileFlusher::ArrowBlockfileFlusher(flusher) => flusher.count(), + } + } } diff --git a/rust/sysdb/src/sysdb.rs b/rust/sysdb/src/sysdb.rs index e72574cbd80..db720c79a78 100644 --- a/rust/sysdb/src/sysdb.rs +++ b/rust/sysdb/src/sysdb.rs @@ -76,6 +76,7 @@ impl SysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { match self { SysDb::Grpc(grpc) => { @@ -85,6 +86,7 @@ impl SysDb { log_position, collection_version, segment_flush_info, + total_records_post_compaction, ) .await } @@ -95,6 +97,7 @@ impl SysDb { log_position, collection_version, segment_flush_info, + total_records_post_compaction, ) .await } @@ -273,6 +276,7 @@ impl GrpcSysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { let segment_compaction_info = segment_flush_info @@ -296,6 +300,7 @@ impl GrpcSysDb { log_position, collection_version, segment_compaction_info, + total_records_post_compaction, }; let res = self.client.flush_collection_compaction(req).await; diff --git a/rust/sysdb/src/test_sysdb.rs b/rust/sysdb/src/test_sysdb.rs index 0af96b4a4ec..ec601148c91 100644 --- a/rust/sysdb/src/test_sysdb.rs +++ b/rust/sysdb/src/test_sysdb.rs @@ -171,6 +171,7 @@ impl TestSysDb { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, ) -> Result { let mut inner = self.inner.lock(); let collection = inner.collections.get(&collection_id); @@ -182,6 +183,7 @@ impl TestSysDb { collection.log_position = log_position; let new_collection_version = collection_version + 1; 
collection.version = new_collection_version; + collection.total_records_post_compaction = total_records_post_compaction; inner .collections .insert(collection.collection_id, collection); diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs index 0c56f94111d..60eb01e3e15 100644 --- a/rust/worker/src/execution/operators/register.rs +++ b/rust/worker/src/execution/operators/register.rs @@ -34,6 +34,7 @@ impl RegisterOperator { /// collection version in sysdb is not the same as the current collection version, the flush /// operation will fail. /// * `segment_flush_info` - The segment flush info. +/// * `total_records_post_compaction` - The total number of records in the collection post compaction. /// * `sysdb` - The sysdb client. /// * `log` - The log client. pub struct RegisterInput { @@ -42,11 +43,13 @@ pub struct RegisterInput { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, sysdb: Box, log: Box, } impl RegisterInput { + #[allow(clippy::too_many_arguments)] /// Create a new flush sysdb input. 
pub fn new( tenant: String, @@ -54,6 +57,7 @@ impl RegisterInput { log_position: i64, collection_version: i32, segment_flush_info: Arc<[SegmentFlushInfo]>, + total_records_post_compaction: u64, sysdb: Box, log: Box, ) -> Self { @@ -63,6 +67,7 @@ impl RegisterInput { log_position, collection_version, segment_flush_info, + total_records_post_compaction, sysdb, log, } @@ -112,6 +117,7 @@ impl Operator for RegisterOperator { input.log_position, input.collection_version, input.segment_flush_info.clone(), + input.total_records_post_compaction, ) .await; @@ -153,6 +159,7 @@ mod tests { let collection_uuid_1 = CollectionUuid::from_str("00000000-0000-0000-0000-000000000001").unwrap(); let tenant_1 = "tenant_1".to_string(); + let total_records_post_compaction: u64 = 5; let collection_1 = Collection { collection_id: collection_uuid_1, name: "collection_1".to_string(), @@ -162,7 +169,7 @@ mod tests { database: "database_1".to_string(), log_position: 0, version: collection_version, - total_records_post_compaction: 0, + total_records_post_compaction, }; let collection_uuid_2 = @@ -177,7 +184,7 @@ mod tests { database: "database_2".to_string(), log_position: 0, version: collection_version, - total_records_post_compaction: 0, + total_records_post_compaction, }; match *sysdb { @@ -244,6 +251,7 @@ mod tests { log_position, collection_version, segment_flush_info.into(), + total_records_post_compaction, sysdb.clone(), log.clone(), ); @@ -270,6 +278,10 @@ mod tests { assert_eq!(collection.len(), 1); let collection = collection[0].clone(); assert_eq!(collection.log_position, log_position); + assert_eq!( + collection.total_records_post_compaction, + total_records_post_compaction + ); let collection_1_segments = sysdb .get_segments(None, None, None, collection_uuid_1) diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index b2659463e6b..3370be7570d 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ 
b/rust/worker/src/execution/orchestration/compact.rs @@ -131,6 +131,8 @@ pub struct CompactOrchestrator { flush_results: Vec, // We track a parent span for each segment type so we can group all the spans for a given segment type (makes the resulting trace much easier to read) segment_spans: HashMap, + // Total number of records in the collection after the compaction + total_records_last_compaction: u64, } #[derive(Error, Debug)] @@ -253,6 +255,7 @@ impl CompactOrchestrator { writers: OnceCell::new(), flush_results: Vec::new(), segment_spans: HashMap::new(), + total_records_last_compaction: 0, } } @@ -471,6 +474,7 @@ impl CompactOrchestrator { log_position, self.compaction_job.collection_version, self.flush_results.clone().into(), + self.total_records_last_compaction, self.sysdb.clone(), self.log.clone(), ); @@ -840,7 +844,13 @@ impl Handler return, }; - self.dispatch_segment_flush(message.flusher, ctx.receiver(), ctx) + let flusher = message.flusher; + // If the flusher received is a record segment flusher, get the number of keys for the blockfile and set it on the orchestrator + if let ChromaSegmentFlusher::RecordSegment(ref record_segment_flusher) = flusher { + self.total_records_last_compaction = record_segment_flusher.count(); + } + + self.dispatch_segment_flush(flusher, ctx.receiver(), ctx) .await; } } diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs index 72eaff4444c..c77a4152f13 100644 --- a/rust/worker/src/segment/metadata_segment.rs +++ b/rust/worker/src/segment/metadata_segment.rs @@ -1387,6 +1387,8 @@ mod test { .commit() .await .expect("Commit for segment writer failed"); + let count = record_flusher.count(); + assert_eq!(count, 2_u64); let metadata_flusher = metadata_writer .commit() .await diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs index 4d28487e47c..83fe6a3fb15 100644 --- a/rust/worker/src/segment/record_segment.rs +++
b/rust/worker/src/segment/record_segment.rs @@ -632,6 +632,10 @@ impl RecordSegmentFlusher { Ok(flushed_files) } + + pub(crate) fn count(&self) -> u64 { + self.id_to_user_id_flusher.count() + } } #[derive(Clone)]