From 4a5b473acc73cc200abbf8c1aab056a2ac6b6c22 Mon Sep 17 00:00:00 2001 From: spikechroma Date: Wed, 21 Aug 2024 14:22:12 -0700 Subject: [PATCH] [ENH] Generate IDs when not given in upsert and add (#2693) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Refactored functions to adhere to SRP to reduce the level of abstractions. - New functionality - When a user calls add or upsert on a collection, they are no longer required to pass in an array of IDs. The IDs will be automatically generated if not given. ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* --- chromadb/api/models/AsyncCollection.py | 17 +-- chromadb/api/models/Collection.py | 19 ++-- chromadb/api/models/CollectionCommon.py | 126 ++++++++-------------- chromadb/api/types.py | 8 ++ chromadb/test/property/strategies.py | 25 +++-- chromadb/test/property/test_embeddings.py | 23 +++- chromadb/test/test_api.py | 5 + 7 files changed, 117 insertions(+), 106 deletions(-) diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 0f7b1c393ff..641118ab810 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -8,6 +8,7 @@ from chromadb.api.types import ( URI, + AddResult, CollectionMetadata, Embedding, Include, @@ -33,7 +34,7 @@ class AsyncCollection(CollectionCommon["AsyncServerAPI"]): async def add( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]] = None, embeddings: Optional[ Union[ OneOrMany[Embedding], @@ -44,7 +45,7 @@ async def add( documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, - ) -> None: + ) -> AddResult: 
"""Add embeddings to the data store. Args: ids: The ids of the embeddings you wish to add @@ -75,7 +76,7 @@ async def add( ) await self._client._add( - embedding_set["ids"], + cast(IDs, embedding_set["ids"]), self.id, cast(Embeddings, embedding_set["embeddings"]), embedding_set["metadatas"], @@ -83,6 +84,10 @@ async def add( embedding_set["uris"], ) + return { + "ids": embedding_set["ids"], + } + async def count(self) -> int: """The total number of embeddings added to the database @@ -259,7 +264,7 @@ async def update( Returns: None """ - embedding_set = self._process_update_request( + embedding_set = self._process_upsert_or_update_request( ids, embeddings, metadatas, documents, images, uris ) @@ -297,13 +302,13 @@ async def upsert( Returns: None """ - embedding_set = self._process_upsert_request( + embedding_set = self._process_upsert_or_update_request( ids, embeddings, metadatas, documents, images, uris ) await self._client._upsert( collection_id=self.id, - ids=embedding_set["ids"], + ids=cast(IDs, embedding_set["ids"]), embeddings=cast(Embeddings, embedding_set["embeddings"]), metadatas=embedding_set["metadatas"], documents=embedding_set["documents"], diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index f4acdd2068a..791b8fe53af 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -9,6 +9,7 @@ Include, Metadata, Document, + AddResult, Image, Where, IDs, @@ -40,7 +41,7 @@ def count(self) -> int: def add( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]] = None, embeddings: Optional[ # type: ignore[type-arg] Union[ OneOrMany[Embedding], @@ -51,7 +52,7 @@ def add( documents: Optional[OneOrMany[Document]] = None, images: Optional[OneOrMany[Image]] = None, uris: Optional[OneOrMany[URI]] = None, - ) -> None: + ) -> AddResult: """Add embeddings to the data store. 
Args: ids: The ids of the embeddings you wish to add @@ -82,7 +83,7 @@ def add( ) self._client._add( - embedding_set["ids"], + cast(IDs, embedding_set["ids"]), self.id, cast(Embeddings, embedding_set["embeddings"]), embedding_set["metadatas"], @@ -90,6 +91,10 @@ def add( embedding_set["uris"], ) + return { + "ids": embedding_set["ids"], + } + def get( self, ids: Optional[OneOrMany[ID]] = None, @@ -257,13 +262,13 @@ def update( Returns: None """ - embedding_set = self._process_update_request( + embedding_set = self._process_upsert_or_update_request( ids, embeddings, metadatas, documents, images, uris ) self._client._update( self.id, - embedding_set["ids"], + cast(IDs, embedding_set["ids"]), cast(Embeddings, embedding_set["embeddings"]), embedding_set["metadatas"], embedding_set["documents"], @@ -295,13 +300,13 @@ def upsert( Returns: None """ - embedding_set = self._process_upsert_request( + embedding_set = self._process_upsert_or_update_request( ids, embeddings, metadatas, documents, images, uris ) self._client._upsert( collection_id=self.id, - ids=embedding_set["ids"], + ids=cast(IDs, embedding_set["ids"]), embeddings=cast(Embeddings, embedding_set["embeddings"]), metadatas=embedding_set["metadatas"], documents=embedding_set["documents"], diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 8b81b9bc992..707931e86b7 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -10,7 +10,7 @@ cast, ) import numpy as np -from uuid import UUID +from uuid import UUID, uuid4 import chromadb.utils.embedding_functions as ef from chromadb.api.types import ( @@ -151,7 +151,7 @@ def get_model(self) -> CollectionModel: def _unpack_embedding_set( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]], embeddings: Optional[ Union[ OneOrMany[Embedding], @@ -181,7 +181,7 @@ def _unpack_embedding_set( def _validate_embedding_set( self, - ids: IDs, + ids: Optional[IDs], embeddings: 
Optional[Embeddings], metadatas: Optional[Metadatas], documents: Optional[Documents], @@ -197,10 +197,6 @@ def _validate_embedding_set( validate_metadatas(metadatas) if metadatas is not None else None ) - valid_documents = maybe_cast_one_to_many_document(documents) - valid_images = maybe_cast_one_to_many_image(images) - valid_uris = maybe_cast_one_to_many_uri(uris) - # Check that one of embeddings or ducuments or images is provided if require_embeddings_or_data: if ( @@ -214,7 +210,7 @@ def _validate_embedding_set( ) # Only one of documents or images can be provided - if valid_documents is not None and valid_images is not None: + if documents is not None and images is not None: raise ValueError("You can only provide documents or images, not both.") # Check that, if they're provided, the lengths of the arrays match the length of ids @@ -226,17 +222,17 @@ def _validate_embedding_set( raise ValueError( f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}" ) - if valid_documents is not None and len(valid_documents) != len(valid_ids): + if documents is not None and len(documents) != len(valid_ids): raise ValueError( - f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}" + f"Number of documents {len(documents)} must match number of ids {len(valid_ids)}" ) - if valid_images is not None and len(valid_images) != len(valid_ids): + if images is not None and len(images) != len(valid_ids): raise ValueError( - f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}" + f"Number of images {len(images)} must match number of ids {len(valid_ids)}" ) - if valid_uris is not None and len(valid_uris) != len(valid_ids): + if uris is not None and len(uris) != len(valid_ids): raise ValueError( - f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}" + f"Number of uris {len(uris)} must match number of ids {len(valid_ids)}" ) def _prepare_embeddings( @@ -426,9 +422,36 @@ def 
_update_model_after_modify_success( if metadata: self._model["metadata"] = metadata + @staticmethod + def _generate_ids_when_not_present( + ids: Optional[IDs], + documents: Optional[Documents], + uris: Optional[URIs], + images: Optional[Images], + embeddings: Optional[Embeddings], + ) -> IDs: + if ids is not None and len(ids) > 0: + return ids + + n = 0 + if documents is not None: + n = len(documents) + elif uris is not None: + n = len(uris) + elif images is not None: + n = len(images) + elif embeddings is not None: + n = len(embeddings) + + generated_ids = [] + for _ in range(n): + generated_ids.append(str(uuid4())) + + return generated_ids + def _process_add_request( self, - ids: OneOrMany[ID], + ids: Optional[OneOrMany[ID]], embeddings: Optional[ Union[ OneOrMany[Embedding], @@ -455,72 +478,16 @@ def _process_add_request( else None ) - self._validate_embedding_set( + generated_ids = self._generate_ids_when_not_present( unpacked_embedding_set["ids"], - normalized_embeddings, - unpacked_embedding_set["metadatas"], unpacked_embedding_set["documents"], - unpacked_embedding_set["images"], unpacked_embedding_set["uris"], - require_embeddings_or_data=False, - ) - - prepared_embeddings = self._prepare_embeddings( - normalized_embeddings, - unpacked_embedding_set["documents"], unpacked_embedding_set["images"], - unpacked_embedding_set["uris"], - ) - - return { - "ids": unpacked_embedding_set["ids"], - "embeddings": prepared_embeddings, - "metadatas": unpacked_embedding_set["metadatas"], - "documents": unpacked_embedding_set["documents"], - "images": unpacked_embedding_set["images"], - "uris": unpacked_embedding_set["uris"], - } - - def _prepare_update_request( - self, - embeddings: Optional[Embeddings], - documents: Optional[Documents], - images: Optional[Images], - ) -> Embeddings: - if embeddings is None: - if documents is not None: - embeddings = self._embed(input=documents) - elif images is not None: - embeddings = self._embed(input=images) - - return 
cast(Embeddings, embeddings) - - def _process_update_request( - self, - ids: OneOrMany[ID], - embeddings: Optional[ # type: ignore[type-arg] - Union[ - OneOrMany[Embedding], - OneOrMany[np.ndarray], - ] - ], - metadatas: Optional[OneOrMany[Metadata]], - documents: Optional[OneOrMany[Document]], - images: Optional[OneOrMany[Image]], - uris: Optional[OneOrMany[URI]], - ) -> EmbeddingSet: - unpacked_embedding_set = self._unpack_embedding_set( - ids, embeddings, metadatas, documents, images, uris - ) - - normalized_embeddings = ( - self._normalize_embeddings(unpacked_embedding_set["embeddings"]) - if unpacked_embedding_set["embeddings"] is not None - else None + normalized_embeddings, ) self._validate_embedding_set( - unpacked_embedding_set["ids"], + generated_ids, normalized_embeddings, unpacked_embedding_set["metadatas"], unpacked_embedding_set["documents"], @@ -529,14 +496,15 @@ def _process_update_request( require_embeddings_or_data=False, ) - prepared_embeddings = self._prepare_update_request( + prepared_embeddings = self._prepare_embeddings( normalized_embeddings, unpacked_embedding_set["documents"], unpacked_embedding_set["images"], + unpacked_embedding_set["uris"], ) return { - "ids": unpacked_embedding_set["ids"], + "ids": generated_ids, "embeddings": prepared_embeddings, "metadatas": unpacked_embedding_set["metadatas"], "documents": unpacked_embedding_set["documents"], @@ -544,7 +512,7 @@ def _process_update_request( "uris": unpacked_embedding_set["uris"], } - def _prepare_upsert_request( + def _prepare_upsert_or_update_request( self, embeddings: Optional[Embeddings], documents: Optional[Documents], @@ -558,7 +526,7 @@ def _prepare_upsert_request( return cast(Embeddings, embeddings) - def _process_upsert_request( + def _process_upsert_or_update_request( self, ids: OneOrMany[ID], embeddings: Optional[ # type: ignore[type-arg] @@ -592,7 +560,7 @@ def _process_upsert_request( require_embeddings_or_data=False, ) - prepared_embeddings = 
self._prepare_upsert_request( + prepared_embeddings = self._prepare_upsert_or_update_request( normalized_embeddings, unpacked_embedding_set["documents"], unpacked_embedding_set["images"], diff --git a/chromadb/api/types.py b/chromadb/api/types.py index a8dcf2b643c..336127bcdab 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -201,6 +201,10 @@ class EmbeddingSet(TypedDict): L = TypeVar("L", covariant=True, bound=Loadable) +class AddResult(TypedDict): + ids: IDs + + class GetResult(TypedDict): ids: List[ID] embeddings: Optional[List[Embedding]] @@ -292,6 +296,10 @@ def validate_ids(ids: IDs) -> IDs: for id_ in ids: if not isinstance(id_, str): raise ValueError(f"Expected ID to be a str, got {id_}") + + if len(id_) == 0: + raise ValueError("Expected ID to be a non-empty str, got an empty string") + if id_ in seen: dups.add(id_) else: diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 28f20b940c7..3074e68886c 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -458,17 +458,26 @@ def recordsets( num_unique_metadata: Optional[int] = None, min_metadata_size: int = 0, max_metadata_size: Optional[int] = None, + # ids can only be optional for add operations + for_add: bool = False, ) -> RecordSet: collection = draw(collection_strategy) - ids = list( draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True)) ) + # This probablistic event is used to mimic user behavior when they don't provide ids + if for_add and np.random.rand() < 0.5: + ids = [] + + n = len(ids) + if len(ids) == 0: + n = int(draw(st.integers(min_value=min_size, max_value=max_size))) + embeddings: Optional[Embeddings] = None if collection.has_embeddings: - embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype) - num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids) + embeddings = create_embeddings(collection.dimension, n, 
collection.dtype) + num_metadata = num_unique_metadata if num_unique_metadata is not None else n generated_metadatas = draw( st.lists( metadata( @@ -479,20 +488,18 @@ def recordsets( ) ) metadatas = [] - for i in range(len(ids)): + for i in range(n): metadatas.append(generated_metadatas[i % len(generated_metadatas)]) documents: Optional[Documents] = None if collection.has_documents: - documents = draw( - st.lists(document(collection), min_size=len(ids), max_size=len(ids)) - ) + documents = draw(st.lists(document(collection), min_size=n, max_size=n)) # in the case where we have a single record, sometimes exercise # the code that handles individual values rather than lists. # In this case, any field may be a list or a single value. - if len(ids) == 1: - single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids + if n == 1: + single_id: Union[str, List[str]] = ids[0] if len(ids) == 1 else ids single_embedding = ( embeddings[0] if embeddings is not None and draw(st.booleans()) diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index dc53bbc52d7..07817552ad5 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -8,7 +8,13 @@ from hypothesis import given, settings, HealthCheck from typing import Dict, Set, cast, Union, DefaultDict, Any, List from dataclasses import dataclass -from chromadb.api.types import ID, Embeddings, Include, IDs, validate_embeddings +from chromadb.api.types import ( + ID, + Embeddings, + Include, + IDs, + validate_embeddings, +) from chromadb.config import System import chromadb.errors as errors from chromadb.api import ClientAPI @@ -104,7 +110,7 @@ def teardown(self) -> None: @rule( target=embedding_ids, - record_set=strategies.recordsets(collection_st), + record_set=strategies.recordsets(collection_st, for_add=True), ) def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]: trace("add_embeddings") @@ 
-114,7 +120,10 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID record_set ) - if len(normalized_record_set["ids"]) > 0: + if ( + normalized_record_set["metadatas"] is not None + and len(normalized_record_set["metadatas"]) > 0 + ): trace("add_more_embeddings") intersection = set(normalized_record_set["ids"]).intersection( @@ -141,9 +150,13 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID return multiple(*filtered_record_set["ids"]) else: - self.collection.add(**normalized_record_set) # type: ignore[arg-type] + result = self.collection.add(**normalized_record_set) # type: ignore[arg-type] + ids = result["ids"] + normalized_record_set["ids"] = ids + self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set)) - return multiple(*normalized_record_set["ids"]) + + return multiple(*ids) @rule(ids=st.lists(consumes(embedding_ids), min_size=1)) def delete_by_ids(self, ids: IDs) -> None: diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 7b9bc763fff..5701ea8c143 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1207,6 +1207,11 @@ def test_invalid_id(client): collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}]) assert "ID" in str(e.value) + # Upsert with an empty id + with pytest.raises(ValueError) as e: + collection.upsert(embeddings=[0, 0, 0], ids=[""]) + assert "non-empty" in str(e.value) + # Get with non-list id with pytest.raises(ValueError) as e: collection.get(ids=1)