Skip to content

Commit

Permalink
[ENH] Update compactor to flush total records upon compaction (#3483)
Browse files Browse the repository at this point in the history
## Description of changes

Updates the compactor to calculate the total records per collection and
flush to the sysdb upon every compaction.

*Summarize the changes made by this PR.*
 - New functionality
- `FlushCollectionCompaction` struct includes
`TotalRecordsPostCompaction`
- `SysDB` populates the `total_records_post_compaction` column when
receiving a flush
   - `ArrowBlockfileFlusher` contains a new attribute, `count` (exposed through its `count()` method)
- `ArrowUnorderedBlockfileWriter` sums the total keys using the
`SparseIndexWriter` and returns an `ArrowBlockfileFlusher` with the
summed count
- `RegisterInput` contains a new attribute, `total_records_post_compaction`
- When handling `CommitSegmentWriterOutput`, if `CompactOrchestrator` receives
a `ChromaSegmentFlusher::RecordSegment`, it reads `total_keys()` and stores
the value as an attribute on itself.
- `ChromaSegmentFlusher::RecordSegment` has `total_keys()` through its
`ArrowBlockfileFlusher`
- `CompactOrchestrator` sends its `num_records_last_compaction` value to
a `RegisterInput` to be flushed to the `SysDB`

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust
- [x] Tested locally and confirmed that compaction correctly updates the
column in the `SysDB`

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
drewkim authored Jan 16, 2025
1 parent 6932c73 commit 537f29a
Show file tree
Hide file tree
Showing 19 changed files with 286 additions and 35 deletions.
11 changes: 6 additions & 5 deletions go/pkg/sysdb/coordinator/model/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ type UpdateCollection struct {
}

// FlushCollectionCompaction carries the data needed to register a completed
// compaction with the sysdb: the collection's identity and tenant, its new
// log position and current version, the per-segment compaction results, and
// the total number of records in the collection after the compaction.
type FlushCollectionCompaction struct {
	ID                         types.UniqueID
	TenantID                   string
	LogPosition                int64
	CurrentCollectionVersion   int32
	FlushSegmentCompactions    []*FlushSegmentCompaction
	TotalRecordsPostCompaction uint64
}

type FlushCollectionInfo struct {
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/sysdb/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ func (tc *Catalog) FlushCollectionCompaction(ctx context.Context, flushCollectio
}

// update collection log position and version
collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionAndVersion(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion)
collectionVersion, err := tc.metaDomain.CollectionDb(txCtx).UpdateLogPositionVersionAndTotalRecords(flushCollectionCompaction.ID.String(), flushCollectionCompaction.LogPosition, flushCollectionCompaction.CurrentCollectionVersion, flushCollectionCompaction.TotalRecordsPostCompaction)
if err != nil {
return err
}
Expand Down
11 changes: 6 additions & 5 deletions go/pkg/sysdb/grpc/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,12 @@ func (s *Server) FlushCollectionCompaction(ctx context.Context, req *coordinator
})
}
FlushCollectionCompaction := &model.FlushCollectionCompaction{
ID: collectionID,
TenantID: req.TenantId,
LogPosition: req.LogPosition,
CurrentCollectionVersion: req.CollectionVersion,
FlushSegmentCompactions: segmentCompactionInfo,
ID: collectionID,
TenantID: req.TenantId,
LogPosition: req.LogPosition,
CurrentCollectionVersion: req.CollectionVersion,
FlushSegmentCompactions: segmentCompactionInfo,
TotalRecordsPostCompaction: req.TotalRecordsPostCompaction,
}
flushCollectionInfo, err := s.coordinator.FlushCollectionCompaction(ctx, FlushCollectionCompaction)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions go/pkg/sysdb/metastore/db/dao/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ func (s *collectionDb) Update(in *dbmodel.Collection) error {
return nil
}

func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) {
log.Info("update log position and version", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion))
func (s *collectionDb) UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error) {
log.Info("update log position, version, and total records post compaction", zap.String("collectionID", collectionID), zap.Int64("logPosition", logPosition), zap.Int32("currentCollectionVersion", currentCollectionVersion), zap.Uint64("totalRecords", totalRecordsPostCompaction))
var collection dbmodel.Collection
// We use select for update to ensure no lost update happens even for isolation level read committed or below
// https://patrick.engineering/posts/postgres-internals/
Expand All @@ -232,7 +232,7 @@ func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosit
}

version := currentCollectionVersion + 1
err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version}).Error
err = s.db.Model(&dbmodel.Collection{}).Where("id = ?", collectionID).Updates(map[string]interface{}{"log_position": logPosition, "version": version, "total_records_post_compaction": totalRecordsPostCompaction}).Error
if err != nil {
return 0, err
}
Expand Down
11 changes: 6 additions & 5 deletions go/pkg/sysdb/metastore/db/dao/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
suite.NoError(err)
}

func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionAndVersion() {
func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionVersionAndTotalRecords() {
collectionName := "test_collection_get_collections"
collectionID, _ := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId)
// verify default values
Expand All @@ -132,22 +132,23 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionAndVersion
suite.Equal(int32(0), collections[0].Collection.Version)

// update log position and version
version, err := suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(10), 0)
version, err := suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(10), 0, uint64(100))
suite.NoError(err)
suite.Equal(int32(1), version)
collections, _ = suite.collectionDb.GetCollections(&collectionID, nil, "", "", nil, nil)
suite.Len(collections, 1)
suite.Equal(int64(10), collections[0].Collection.LogPosition)
suite.Equal(int32(1), collections[0].Collection.Version)
suite.Equal(uint64(100), collections[0].Collection.TotalRecordsPostCompaction)

// invalid log position
_, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(5), 0)
_, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(5), 0, uint64(100))
suite.Error(err, "collection log position Stale")

// invalid version
_, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 0)
_, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 0, uint64(100))
suite.Error(err, "collection version invalid")
_, err = suite.collectionDb.UpdateLogPositionAndVersion(collectionID, int64(20), 3)
_, err = suite.collectionDb.UpdateLogPositionVersionAndTotalRecords(collectionID, int64(20), 3, uint64(100))
suite.Error(err, "collection version invalid")

//clean up
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/sysdb/metastore/db/dbmodel/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ type ICollectionDb interface {
Insert(in *Collection) error
Update(in *Collection) error
DeleteAll() error
UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error)
UpdateLogPositionVersionAndTotalRecords(collectionID string, logPosition int64, currentCollectionVersion int32, totalRecordsPostCompaction uint64) (int32, error)
GetCollectionEntry(collectionID *string, databaseName *string) (*Collection, error)
}
20 changes: 10 additions & 10 deletions go/pkg/sysdb/metastore/db/dbmodel/mocks/ICollectionDb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions go/pkg/sysdb/metastore/db/dbmodel/mocks/ISegmentDb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions idl/chromadb/proto/coordinator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ message FlushCollectionCompactionRequest {
int64 log_position = 3;
int32 collection_version = 4;
repeated FlushSegmentCompactionInfo segment_compaction_info = 5;
uint64 total_records_post_compaction = 6;
}

message FlushCollectionCompactionResponse {
Expand Down
86 changes: 85 additions & 1 deletion rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,23 @@ impl ArrowUnorderedBlockfileWriter {
Box::new(ArrowBlockfileError::MigrationError(e)) as Box<dyn ChromaError>
})?;

let count = self
.root
.sparse_index
.data
.lock()
.counts
.values()
.map(|&x| x as u64)
.sum::<u64>();

let flusher = ArrowBlockfileFlusher::new(
self.block_manager,
self.root_manager,
blocks,
self.root,
self.id,
count,
);

Ok(flusher)
Expand Down Expand Up @@ -705,7 +716,7 @@ mod tests {
use uuid::Uuid;

#[tokio::test]
async fn test_count() {
async fn test_reader_count() {
let tmp_dir = tempfile::tempdir().unwrap();
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = new_cache_for_test();
Expand Down Expand Up @@ -744,6 +755,79 @@ mod tests {
}
}

// Verifies that the flusher returned by commit() reports the number of keys
// written: zero for an empty writer, the exact count after sets, counts
// scoped to each commit independently, and zero again for a fresh writer
// created after earlier flushes.
#[tokio::test]
async fn test_writer_count() {
    let tmp_dir = tempfile::tempdir().unwrap();
    let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
    let block_cache = new_cache_for_test();
    let sparse_index_cache = new_cache_for_test();
    let blockfile_provider = ArrowBlockfileProvider::new(
        storage,
        TEST_MAX_BLOCK_SIZE_BYTES,
        block_cache,
        sparse_index_cache,
    );

    // Test no keys: an empty writer commits to a flusher with count 0.
    let writer = blockfile_provider
        .write::<&str, Vec<u32>>(BlockfileWriterOptions::default())
        .await
        .unwrap();

    let flusher = writer.commit::<&str, Vec<u32>>().await.unwrap();
    assert_eq!(0_u64, flusher.count());
    flusher.flush::<&str, Vec<u32>>().await.unwrap();

    // Test 2 keys: two sets yield a flusher count of 2.
    let writer = blockfile_provider
        .write::<&str, Vec<u32>>(BlockfileWriterOptions::default())
        .await
        .unwrap();

    let prefix_1 = "key";
    let key1 = "zzzz";
    let value1 = vec![1, 2, 3];
    // No clone needed: the value is not used again after the set.
    writer.set(prefix_1, key1, value1).await.unwrap();

    let prefix_2 = "key";
    let key2 = "aaaa";
    let value2 = vec![4, 5, 6];
    writer.set(prefix_2, key2, value2).await.unwrap();

    let flusher1 = writer.commit::<&str, Vec<u32>>().await.unwrap();
    assert_eq!(2_u64, flusher1.count());

    // Test add keys after commit, before flush: a second writer committed
    // while flusher1 is still unflushed counts only its own keys.
    let writer = blockfile_provider
        .write::<&str, Vec<u32>>(BlockfileWriterOptions::default())
        .await
        .unwrap();

    let prefix_3 = "key";
    let key3 = "yyyy";
    let value3 = vec![7, 8, 9];
    writer.set(prefix_3, key3, value3).await.unwrap();

    let prefix_4 = "key";
    let key4 = "bbbb";
    let value4 = vec![10, 11, 12];
    writer.set(prefix_4, key4, value4).await.unwrap();

    let flusher2 = writer.commit::<&str, Vec<u32>>().await.unwrap();
    assert_eq!(2_u64, flusher2.count());

    flusher1.flush::<&str, Vec<u32>>().await.unwrap();
    flusher2.flush::<&str, Vec<u32>>().await.unwrap();

    // Test count after flush: a fresh writer still reports 0 for its own commit.
    let writer = blockfile_provider
        .write::<&str, Vec<u32>>(BlockfileWriterOptions::default())
        .await
        .unwrap();
    let flusher = writer.commit::<&str, Vec<u32>>().await.unwrap();
    assert_eq!(0_u64, flusher.count());
}

fn test_prefix(num_keys: u32, prefix_for_query: u32) {
Runtime::new().unwrap().block_on(async {
let tmp_dir = tempfile::tempdir().unwrap();
Expand Down
7 changes: 7 additions & 0 deletions rust/blockstore/src/arrow/flusher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct ArrowBlockfileFlusher {
blocks: Vec<Block>,
root: RootWriter,
id: Uuid,
count: u64,
}

impl ArrowBlockfileFlusher {
Expand All @@ -23,13 +24,15 @@ impl ArrowBlockfileFlusher {
blocks: Vec<Block>,
root: RootWriter,
id: Uuid,
count: u64,
) -> Self {
Self {
block_manager,
root_manager,
blocks,
root,
id,
count,
}
}

Expand Down Expand Up @@ -68,4 +71,8 @@ impl ArrowBlockfileFlusher {
pub(crate) fn id(&self) -> Uuid {
self.id
}

pub(crate) fn count(&self) -> u64 {
self.count
}
}
Loading

0 comments on commit 537f29a

Please sign in to comment.