Skip to content

Commit

Permalink
[BUG] Log service incorrectly doesn't hydrate collection id (#1922)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Log service incorrectly did not return the collection ID of the
record. I suspect that we meant to strip the collection id and rehydrate
it into the proto for space reasons. I honored this intent and now rehydrate
the collection id.
	 - Add some readme for running tests locally with PG
- Fixed a bug in the log_service test where the input is mutated, which
leaves the source of truth (SOT) with no collection id. This was passing when
we incorrectly returned no collection id, but correctly fails now.
I patched the test by cloning the records for a SOT.
	 - For the test I fixed, the expected vs actual order was incorrect.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Mar 24, 2024
1 parent 3a48455 commit ba7b52e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
4 changes: 4 additions & 0 deletions go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
- postgres=# `create role chroma with login password 'chroma';`
- postgres=# `alter role chroma with superuser;`
- postgres=# `create database chroma;`
- Set postgres ENV Vars
Several tests (such as record_log_service_test.go) require the following environment variables to be set:
- `export POSTGRES_HOST=localhost`
- `export POSTGRES_PORT=5432`
- Atlas schema migration
- [~/chroma/go]: `atlas migrate diff --env dev`
- [~/chroma/go]: `atlas --env dev migrate apply --url "postgres://chroma:chroma@localhost:5432/chroma?sslmode=disable"`
4 changes: 4 additions & 0 deletions go/pkg/logservice/grpc/record_log_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func (s *Server) PushLogs(ctx context.Context, req *logservicepb.PushLogsRequest
}
var recordsContent [][]byte
for _, record := range req.Records {
// We remove the collection id for space reasons, as it's double-stored in the wrapping database RecordLog object.
// PullLogs will rehydrate the collection id from the database.
record.CollectionId = ""
data, err := proto.Marshal(record)
if err != nil {
Expand Down Expand Up @@ -73,6 +75,8 @@ func (s *Server) PullLogs(ctx context.Context, req *logservicepb.PullLogsRequest
}
return nil, grpcError
}
// Here we rehydrate the collection id from the database since in PushLogs we removed it for space reasons.
record.CollectionId = *recordLogs[index].CollectionID
recordLog := &logservicepb.RecordLog{
LogId: recordLogs[index].ID,
Record: record,
Expand Down
24 changes: 15 additions & 9 deletions go/pkg/logservice/grpc/record_log_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"bytes"
"context"
"encoding/binary"
"testing"
"time"

"github.com/chroma-core/chroma/go/pkg/logservice/testutils"
"github.com/chroma-core/chroma/go/pkg/metastore/db/dbcore"
"github.com/chroma-core/chroma/go/pkg/metastore/db/dbmodel"
Expand All @@ -16,8 +19,6 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"gorm.io/gorm"
"testing"
"time"
)

type RecordLogServiceTestSuite struct {
Expand Down Expand Up @@ -132,6 +133,11 @@ func (suite *RecordLogServiceTestSuite) TestServer_PushLogs() {
func (suite *RecordLogServiceTestSuite) TestServer_PullLogs() {
// push some records
recordsToSubmit := GetTestEmbeddingRecords(suite.collectionId.String())
// deep clone the records since PushLogs will mutate the records and we need a source of truth
recordsToSubmit_sot := make([]*coordinatorpb.SubmitEmbeddingRecord, len(recordsToSubmit))
for i := range recordsToSubmit {
recordsToSubmit_sot[i] = proto.Clone(recordsToSubmit[i]).(*coordinatorpb.SubmitEmbeddingRecord)
}
pushRequest := logservicepb.PushLogsRequest{
CollectionId: suite.collectionId.String(),
Records: recordsToSubmit,
Expand All @@ -150,13 +156,13 @@ func (suite *RecordLogServiceTestSuite) TestServer_PullLogs() {
suite.Len(pullResponse.Records, 3)
for index := range pullResponse.Records {
suite.Equal(int64(index+1), pullResponse.Records[index].LogId)
suite.Equal(pullResponse.Records[index].Record.Id, recordsToSubmit[index].Id)
suite.Equal(pullResponse.Records[index].Record.Operation, recordsToSubmit[index].Operation)
suite.Equal(pullResponse.Records[index].Record.CollectionId, recordsToSubmit[index].CollectionId)
suite.Equal(pullResponse.Records[index].Record.Metadata, recordsToSubmit[index].Metadata)
suite.Equal(pullResponse.Records[index].Record.Vector.Dimension, recordsToSubmit[index].Vector.Dimension)
suite.Equal(pullResponse.Records[index].Record.Vector.Encoding, recordsToSubmit[index].Vector.Encoding)
suite.Equal(pullResponse.Records[index].Record.Vector.Vector, recordsToSubmit[index].Vector.Vector)
suite.Equal(recordsToSubmit_sot[index].Id, pullResponse.Records[index].Record.Id)
suite.Equal(recordsToSubmit_sot[index].Operation, pullResponse.Records[index].Record.Operation)
suite.Equal(recordsToSubmit_sot[index].CollectionId, pullResponse.Records[index].Record.CollectionId)
suite.Equal(recordsToSubmit_sot[index].Metadata, pullResponse.Records[index].Record.Metadata)
suite.Equal(recordsToSubmit_sot[index].Vector.Dimension, pullResponse.Records[index].Record.Vector.Dimension)
suite.Equal(recordsToSubmit_sot[index].Vector.Encoding, pullResponse.Records[index].Record.Vector.Encoding)
suite.Equal(recordsToSubmit_sot[index].Vector.Vector, pullResponse.Records[index].Record.Vector.Vector)
}
}

Expand Down

0 comments on commit ba7b52e

Please sign in to comment.