Skip to content

Commit

Permalink
[BUG] Move getOrCreate logic down to system catalog (#1352)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - This fixes #1284
 - New functionality
	 - ...

## Test plan
*How are these changes tested?*

- [ ] test_system.py

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Ishiihara authored Nov 8, 2023
1 parent 933d46d commit 233a7cc
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 221 deletions.
14 changes: 4 additions & 10 deletions go/coordinator/internal/coordinator/apis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func SampleCollections(t *testing.T, tenantID string, databaseName string) []*mo
Topic: "test_topic_1",
Metadata: metadata1,
Dimension: &dimension,
Created: true,
TenantID: tenantID,
DatabaseName: databaseName,
},
Expand All @@ -185,6 +186,7 @@ func SampleCollections(t *testing.T, tenantID string, databaseName string) []*mo
Topic: "test_topic_2",
Metadata: metadata2,
Dimension: nil,
Created: true,
TenantID: tenantID,
DatabaseName: databaseName,
},
Expand All @@ -194,6 +196,7 @@ func SampleCollections(t *testing.T, tenantID string, databaseName string) []*mo
Topic: "test_topic_3",
Metadata: metadata3,
Dimension: nil,
Created: true,
TenantID: tenantID,
DatabaseName: databaseName,
},
Expand Down Expand Up @@ -343,6 +346,7 @@ func TestUpdateCollections(t *testing.T) {
Topic: sampleCollections[0].Topic,
Metadata: sampleCollections[0].Metadata,
Dimension: sampleCollections[0].Dimension,
Created: false,
TenantID: sampleCollections[0].TenantID,
DatabaseName: sampleCollections[0].DatabaseName,
}
Expand Down Expand Up @@ -463,16 +467,6 @@ func TestCreateUpdateWithDatabase(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 1, len(result))
assert.Equal(t, "new_name_0", result[0].Name)
// # Try to create the collection in the default database in the new database and expect an error
// with pytest.raises(UniqueConstraintError):
// sysdb.create_collection(
// id=sample_collections[1]["id"],
// name=sample_collections[1]["name"],
// metadata=sample_collections[1]["metadata"],
// dimension=sample_collections[1]["dimension"],
// database="new_database",
// )
//
}

func TestGetMultipleWithDatabase(t *testing.T) {
Expand Down
1 change: 0 additions & 1 deletion go/coordinator/internal/coordinator/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ func (mt *MetaTable) GetCollections(ctx context.Context, collectionID types.Uniq
}
collections := make([]*model.Collection, 0, len(mt.tenantDatabaseCollectionCache[tenantID][databaseName]))
for _, collection := range mt.tenantDatabaseCollectionCache[tenantID][databaseName] {
log.Error("collection", zap.Any("collection", collection))
if model.FilterCollection(collection, collectionID, collectionName, collectionTopic) {
collections = append(collections, collection)
}
Expand Down
89 changes: 6 additions & 83 deletions go/coordinator/internal/grpccoordinator/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ func (s *Server) ResetState(context.Context, *emptypb.Empty) (*coordinatorpb.Chr
return res, nil
}

func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.CreateCollectionRequest) (*coordinatorpb.CreateCollectionResponse, error) {
getOrCreate := req.GetGetOrCreate()
if getOrCreate {
return s.getOrCreateCollection(ctx, req)
} else {
return s.createCollection(ctx, req)
}
}

// Cases for get_or_create

// Case 0
Expand All @@ -58,79 +49,7 @@ func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.Create

// The fact that we ignore the metadata of the generated collections is a
// bit weird, but it is the easiest way to excercise all cases
func (s *Server) getOrCreateCollection(ctx context.Context, req *coordinatorpb.CreateCollectionRequest) (*coordinatorpb.CreateCollectionResponse, error) {
res := &coordinatorpb.CreateCollectionResponse{}
name := req.GetName()
tenantID := req.GetTenant()
databaseName := req.GetDatabase()
collections, err := s.coordinator.GetCollections(ctx, types.NilUniqueID(), &name, nil, tenantID, databaseName)
if err != nil {
log.Error("error getting collections", zap.Error(err))
res.Collection = &coordinatorpb.Collection{
Id: req.Id,
Name: req.Name,
Dimension: req.Dimension,
Metadata: req.Metadata,
}
res.Created = false
res.Status = failResponseWithError(err, errorCode)
return res, nil
}
if len(collections) > 0 { // collection exists, need to update the metadata
if req.Metadata != nil { // update existing collection with new metadata
metadata, err := convertCollectionMetadataToModel(req.Metadata)
if err != nil {
log.Error("error converting collection metadata to model", zap.Error(err))
res.Collection = &coordinatorpb.Collection{
Id: req.Id,
Name: req.Name,
Dimension: req.Dimension,
Metadata: req.Metadata,
}
res.Created = false
res.Status = failResponseWithError(err, errorCode)
return res, nil
}
// update collection with new metadata
updateCollection := &model.UpdateCollection{
ID: collections[0].ID,
Metadata: metadata,
}
updatedCollection, err := s.coordinator.UpdateCollection(ctx, updateCollection)
if err != nil {
log.Error("error updating collection", zap.Error(err))
res.Collection = &coordinatorpb.Collection{
Id: req.Id,
Name: req.Name,
Dimension: req.Dimension,
Metadata: req.Metadata,
}
res.Created = false
res.Status = failResponseWithError(err, errorCode)
return res, nil
}
// sucessfully update the metadata
res.Collection = convertCollectionToProto(updatedCollection)
res.Created = false
res.Status = setResponseStatus(successCode)
return res, nil
} else { // do nothing, return the existing collection
res.Collection = &coordinatorpb.Collection{
Id: req.Id,
Name: req.Name,
Dimension: req.Dimension,
}
res.Collection.Metadata = convertCollectionMetadataToProto(collections[0].Metadata)
res.Created = false
res.Status = setResponseStatus(successCode)
return res, nil
}
} else { // collection does not exist, need to create it
return s.createCollection(ctx, req)
}
}

func (s *Server) createCollection(ctx context.Context, req *coordinatorpb.CreateCollectionRequest) (*coordinatorpb.CreateCollectionResponse, error) {
func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.CreateCollectionRequest) (*coordinatorpb.CreateCollectionResponse, error) {
res := &coordinatorpb.CreateCollectionResponse{}
createCollection, err := convertToCreateCollectionModel(req)
if err != nil {
Expand All @@ -140,6 +59,8 @@ func (s *Server) createCollection(ctx context.Context, req *coordinatorpb.Create
Name: req.Name,
Dimension: req.Dimension,
Metadata: req.Metadata,
Tenant: req.Tenant,
Database: req.Database,
}
res.Created = false
res.Status = failResponseWithError(err, successCode)
Expand All @@ -153,6 +74,8 @@ func (s *Server) createCollection(ctx context.Context, req *coordinatorpb.Create
Name: req.Name,
Dimension: req.Dimension,
Metadata: req.Metadata,
Tenant: req.Tenant,
Database: req.Database,
}
res.Created = false
if err == common.ErrCollectionUniqueConstraintViolation {
Expand All @@ -163,7 +86,7 @@ func (s *Server) createCollection(ctx context.Context, req *coordinatorpb.Create
return res, nil
}
res.Collection = convertCollectionToProto(collection)
res.Created = true
res.Created = collection.Created
res.Status = setResponseStatus(successCode)
return res, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ func convertCollectionToProto(collection *model.Collection) *coordinatorpb.Colle
Name: collection.Name,
Topic: collection.Topic,
Dimension: collection.Dimension,
Tenant: collection.TenantID,
Database: collection.DatabaseName,
}
if collection.Metadata == nil {
return collectionpb
Expand Down Expand Up @@ -104,6 +106,7 @@ func convertToCreateCollectionModel(req *coordinatorpb.CreateCollectionRequest)
Name: req.Name,
Dimension: req.Dimension,
Metadata: metadata,
GetOrCreate: req.GetGetOrCreate(),
TenantID: req.GetTenant(),
DatabaseName: req.GetDatabase(),
}, nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@ func (mc *MemoryCatalog) CreateCollection(ctx context.Context, createCollection
log.Info("collection already exists", zap.Any("collection", collections[createCollection.ID]))
if createCollection.GetOrCreate {
if createCollection.Metadata != nil {
// For getOrCreate, update the metadata
collection.Metadata = createCollection.Metadata
}
return collection, nil
} else {
return nil, common.ErrCollectionUniqueConstraintViolation
}
return nil, common.ErrCollectionUniqueConstraintViolation
}
}
collection := &model.Collection{
Expand All @@ -175,6 +177,7 @@ func (mc *MemoryCatalog) CreateCollection(ctx context.Context, createCollection
Topic: createCollection.Topic,
Dimension: createCollection.Dimension,
Metadata: createCollection.Metadata,
Created: true,
TenantID: createCollection.TenantID,
DatabaseName: createCollection.DatabaseName,
}
Expand Down
38 changes: 34 additions & 4 deletions go/coordinator/internal/metastore/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (tc *Catalog) GetAllTenants(ctx context.Context, ts types.Timestamp) ([]*mo
}

func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, error) {
var ressult *model.Collection
var result *model.Collection

err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error {
// insert collection
Expand All @@ -213,6 +213,35 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model
return common.ErrDatabaseNotFound
}

collectionName := createCollection.Name
existing, err := tc.metaDomain.CollectionDb(txCtx).GetCollections(types.FromUniqueID(createCollection.ID), &collectionName, nil, tenantID, databaseName)
if err != nil {
log.Error("error getting collection", zap.Error(err))
return err
}
if len(existing) != 0 {
if createCollection.GetOrCreate {
collection := convertCollectionToModel(existing)[0]
if createCollection.Metadata != nil && !createCollection.Metadata.Equals(collection.Metadata) {
updatedCollection, err := tc.UpdateCollection(ctx, &model.UpdateCollection{
ID: collection.ID,
Metadata: createCollection.Metadata,
TenantID: tenantID,
DatabaseName: databaseName,
}, ts)
if err != nil {
log.Error("error updating collection", zap.Error(err))
}
result = updatedCollection
} else {
result = collection
}
return nil
} else {
return common.ErrCollectionUniqueConstraintViolation
}
}

dbCollection := &dbmodel.Collection{
ID: createCollection.ID.String(),
Name: &createCollection.Name,
Expand Down Expand Up @@ -242,15 +271,16 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model
log.Error("error getting collection", zap.Error(err))
return err
}
ressult = convertCollectionToModel(collectionList)[0]
result = convertCollectionToModel(collectionList)[0]
result.Created = true
return nil
})
if err != nil {
log.Error("error creating collection", zap.Error(err))
return nil, err
}
log.Info("collection created", zap.Any("collection", ressult))
return ressult, nil
log.Info("collection created", zap.Any("collection", result))
return result, nil
}

func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, collectionTopic *string, tenandID string, databaseName string) ([]*model.Collection, error) {
Expand Down
1 change: 1 addition & 0 deletions go/coordinator/internal/model/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type Collection struct {
Topic string
Dimension *int32
Metadata *CollectionMetadata[CollectionMetadataValueType]
Created bool
TenantID string
DatabaseName string
Ts types.Timestamp
Expand Down
43 changes: 43 additions & 0 deletions go/coordinator/internal/model/collection_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package model

type CollectionMetadataValueType interface {
IsCollectionMetadataValueType()
Equals(other CollectionMetadataValueType) bool
}

type CollectionMetadataValueStringType struct {
Expand All @@ -10,18 +11,39 @@ type CollectionMetadataValueStringType struct {

func (s *CollectionMetadataValueStringType) IsCollectionMetadataValueType() {}

func (s *CollectionMetadataValueStringType) Equals(other CollectionMetadataValueType) bool {
if o, ok := other.(*CollectionMetadataValueStringType); ok {
return s.Value == o.Value
}
return false
}

type CollectionMetadataValueInt64Type struct {
Value int64
}

func (s *CollectionMetadataValueInt64Type) IsCollectionMetadataValueType() {}

func (s *CollectionMetadataValueInt64Type) Equals(other CollectionMetadataValueType) bool {
if o, ok := other.(*CollectionMetadataValueInt64Type); ok {
return s.Value == o.Value
}
return false
}

type CollectionMetadataValueFloat64Type struct {
Value float64
}

func (s *CollectionMetadataValueFloat64Type) IsCollectionMetadataValueType() {}

func (s *CollectionMetadataValueFloat64Type) Equals(other CollectionMetadataValueType) bool {
if o, ok := other.(*CollectionMetadataValueFloat64Type); ok {
return s.Value == o.Value
}
return false
}

type CollectionMetadata[T CollectionMetadataValueType] struct {
Metadata map[string]T
}
Expand All @@ -47,3 +69,24 @@ func (m *CollectionMetadata[T]) Remove(key string) {
func (m *CollectionMetadata[T]) Empty() bool {
return len(m.Metadata) == 0
}

func (m *CollectionMetadata[T]) Equals(other *CollectionMetadata[T]) bool {
if m == nil && other == nil {
return true
}
if m == nil && other != nil {
return false
}
if m != nil && other == nil {
return false
}
if len(m.Metadata) != len(other.Metadata) {
return false
}
for key, value := range m.Metadata {
if otherValue, ok := other.Metadata[key]; !ok || !value.Equals(otherValue) {
return false
}
}
return true
}
Loading

0 comments on commit 233a7cc

Please sign in to comment.