From 64645ef3a48c192cb330e3b15f8446124fe7a898 Mon Sep 17 00:00:00 2001 From: Spike Lu Date: Thu, 22 Aug 2024 00:02:52 -0700 Subject: [PATCH] move ids gen to server side --- chromadb/api/models/CollectionCommon.py | 45 ++++--------------------- chromadb/api/types.py | 14 +++++--- chromadb/errors.py | 11 ++++++ chromadb/server/fastapi/__init__.py | 38 ++++++++++++++++++++- 4 files changed, 64 insertions(+), 44 deletions(-) diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 934659d536b..ce40d8bf13c 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -171,7 +171,7 @@ def _unpack_embedding_set( unpacked_uris = maybe_cast_one_to_many_uri(uris) return { - "ids": unpacked_ids, + "ids": unpacked_ids if unpacked_ids is not None else [], "embeddings": unpacked_embeddings, "metadatas": unpacked_metadatas, "documents": unpacked_documents, @@ -188,8 +188,9 @@ def _validate_embedding_set( images: Optional[Images], uris: Optional[URIs], require_embeddings_or_data: bool = True, + can_ids_be_empty: bool = False, ) -> None: - valid_ids = validate_ids(ids) + valid_ids = validate_ids(ids, can_ids_be_empty=can_ids_be_empty) valid_embeddings = ( validate_embeddings(embeddings) if embeddings is not None else None ) @@ -426,33 +427,6 @@ def _update_model_after_modify_success( if metadata: self._model["metadata"] = metadata - @staticmethod - def _generate_ids_when_not_present( - ids: Optional[IDs], - documents: Optional[Documents], - uris: Optional[URIs], - images: Optional[Images], - embeddings: Optional[Embeddings], - ) -> IDs: - if ids is not None and len(ids) > 0: - return ids - - n = 0 - if documents is not None: - n = len(documents) - elif uris is not None: - n = len(uris) - elif images is not None: - n = len(images) - elif embeddings is not None: - n = len(embeddings) - - generated_ids = [] - for _ in range(n): - generated_ids.append(str(uuid4())) - - return generated_ids - def _process_add_request( self, ids: Optional[OneOrMany[ID]], @@ -482,22 +456,15 @@ def _process_add_request( else None ) - generated_ids = self._generate_ids_when_not_present( - unpacked_embedding_set["ids"], - unpacked_embedding_set["documents"], - unpacked_embedding_set["uris"], - unpacked_embedding_set["images"], - normalized_embeddings, - ) - self._validate_embedding_set( - generated_ids, + unpacked_embedding_set["ids"], normalized_embeddings, unpacked_embedding_set["metadatas"], unpacked_embedding_set["documents"], unpacked_embedding_set["images"], unpacked_embedding_set["uris"], require_embeddings_or_data=False, + can_ids_be_empty=True, ) prepared_embeddings = self._compute_embeddings( @@ -508,7 +475,7 @@ def _process_add_request( ) return { - "ids": generated_ids, + "ids": unpacked_embedding_set["ids"], "embeddings": prepared_embeddings, "metadatas": unpacked_embedding_set["metadatas"], "documents": unpacked_embedding_set["documents"], diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 6a395d39379..7c19c2cb0a9 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -48,7 +48,11 @@ def maybe_cast_one_to_many_uri(target: Optional[OneOrMany[URI]]) -> Optional[URI IDs = List[ID] -def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: +def maybe_cast_one_to_many_ids(target: Optional[OneOrMany[ID]]) -> Optional[IDs]: + # No target + if target is None: + return None + if isinstance(target, str): # One ID return cast(IDs, [target]) @@ -283,12 +287,14 @@ def __call__(self, uris: URIs) -> L: ... -def validate_ids(ids: IDs) -> IDs: +def validate_ids(ids: Optional[IDs], can_ids_be_empty: bool = False) -> IDs: """Validates ids to ensure it is a list of strings""" + + if can_ids_be_empty and (ids is None or len(ids) == 0): + return [] + if not isinstance(ids, list): raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs") - if len(ids) == 0: - raise ValueError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs") seen = set() dups = set() for id_ in ids: diff --git a/chromadb/errors.py b/chromadb/errors.py index ff3a37a8692..1ca07c3f69c 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -45,6 +45,17 @@ def name(cls) -> str: return "IDAlreadyExists" +class NoRecordsError(ChromaError): + @overrides + def code(self) -> int: + return 400 # Bad Request + + @classmethod + @overrides + def name(cls) -> str: + return "NoRecordsError" + + class ChromaAuthError(ChromaError): @overrides def code(self) -> int: diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index bf146a2dd78..d0c0878058a 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -21,7 +21,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute from fastapi import HTTPException, status -from uuid import UUID +from uuid import UUID, uuid4 from chromadb.api.configuration import CollectionConfigurationInternal from pydantic import BaseModel @@ -35,6 +35,7 @@ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System from chromadb.api import ServerAPI from chromadb.errors import ( + NoRecordsError, ChromaError, InvalidDimensionException, InvalidHTTPVersion, @@ -744,6 +745,34 @@ async def delete_collection( limiter=self._capacity_limiter, ) + @staticmethod + def _generate_ids_when_not_present( + ids: Optional[List[str]], + documents: Optional[List[Optional[str]]], + uris: Optional[List[Optional[str]]], + embeddings: Optional[List[Any]], + ) -> List[str]: + if ids is not None and len(ids) > 0: + return ids + + n = 0 + if documents is not None: + n = len(documents) + elif uris is not None: + n = len(uris) + elif embeddings is not None: + n = len(embeddings) + + generated_ids = [] + for _ in range(n): + generated_ids.append(str(uuid4())) + + return generated_ids + + def _validate_add_embedding(self, add: AddEmbedding): + if len(add.ids) == 0: + raise NoRecordsError("No records to add") + @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) async def add( self, request: Request, collection_id: str, body: AddEmbedding = Body(...) @@ -759,6 +788,13 @@ def process_add(request: Request, raw_body: bytes) -> bool: None, collection_id, ) + + ids = self._generate_ids_when_not_present( + add.ids, add.documents, add.uris, add.embeddings + ) + add.ids = ids + self._validate_add_embedding(add) + return self._api._add( collection_id=_uuid(collection_id), ids=add.ids,