Skip to content

Commit

Permalink
move ids gen to server side
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 22, 2024
1 parent 71e3b00 commit 64645ef
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 44 deletions.
45 changes: 6 additions & 39 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _unpack_embedding_set(
unpacked_uris = maybe_cast_one_to_many_uri(uris)

return {
"ids": unpacked_ids,
"ids": unpacked_ids if unpacked_ids is not None else [],
"embeddings": unpacked_embeddings,
"metadatas": unpacked_metadatas,
"documents": unpacked_documents,
Expand All @@ -188,8 +188,9 @@ def _validate_embedding_set(
images: Optional[Images],
uris: Optional[URIs],
require_embeddings_or_data: bool = True,
can_ids_be_empty: bool = False,
) -> None:
valid_ids = validate_ids(ids)
valid_ids = validate_ids(ids, can_ids_be_empty=can_ids_be_empty)
valid_embeddings = (
validate_embeddings(embeddings) if embeddings is not None else None
)
Expand Down Expand Up @@ -426,33 +427,6 @@ def _update_model_after_modify_success(
if metadata:
self._model["metadata"] = metadata

@staticmethod
def _generate_ids_when_not_present(
ids: Optional[IDs],
documents: Optional[Documents],
uris: Optional[URIs],
images: Optional[Images],
embeddings: Optional[Embeddings],
) -> IDs:
if ids is not None and len(ids) > 0:
return ids

n = 0
if documents is not None:
n = len(documents)
elif uris is not None:
n = len(uris)
elif images is not None:
n = len(images)
elif embeddings is not None:
n = len(embeddings)

generated_ids = []
for _ in range(n):
generated_ids.append(str(uuid4()))

return generated_ids

def _process_add_request(
self,
ids: Optional[OneOrMany[ID]],
Expand Down Expand Up @@ -482,22 +456,15 @@ def _process_add_request(
else None
)

generated_ids = self._generate_ids_when_not_present(
unpacked_embedding_set["ids"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["uris"],
unpacked_embedding_set["images"],
normalized_embeddings,
)

self._validate_embedding_set(
generated_ids,
unpacked_embedding_set["ids"],
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["images"],
unpacked_embedding_set["uris"],
require_embeddings_or_data=False,
can_ids_be_empty=True,
)

prepared_embeddings = self._compute_embeddings(
Expand All @@ -508,7 +475,7 @@ def _process_add_request(
)

return {
"ids": generated_ids,
"ids": unpacked_embedding_set["ids"],
"embeddings": prepared_embeddings,
"metadatas": unpacked_embedding_set["metadatas"],
"documents": unpacked_embedding_set["documents"],
Expand Down
14 changes: 10 additions & 4 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def maybe_cast_one_to_many_uri(target: Optional[OneOrMany[URI]]) -> Optional[URI
IDs = List[ID]


def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:
def maybe_cast_one_to_many_ids(target: Optional[OneOrMany[ID]]) -> Optional[IDs]:
# No target
if target is None:
return None

if isinstance(target, str):
# One ID
return cast(IDs, [target])
Expand Down Expand Up @@ -283,12 +287,14 @@ def __call__(self, uris: URIs) -> L:
...


def validate_ids(ids: IDs) -> IDs:
def validate_ids(ids: Optional[IDs], can_ids_be_empty: bool = False) -> IDs:
"""Validates ids to ensure it is a list of strings"""

if can_ids_be_empty and (ids is None or len(ids) == 0):
return []

if not isinstance(ids, list):
raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs")
if len(ids) == 0:
raise ValueError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs")
seen = set()
dups = set()
for id_ in ids:
Expand Down
11 changes: 11 additions & 0 deletions chromadb/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ def name(cls) -> str:
return "IDAlreadyExists"


class NoRecordsError(ChromaError):
@overrides
def code(self) -> int:
return 400 # Bad Request

@classmethod
@overrides
def name(cls) -> str:
return "NoRecordsError"


class ChromaAuthError(ChromaError):
@overrides
def code(self) -> int:
Expand Down
38 changes: 37 additions & 1 deletion chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi import HTTPException, status
from uuid import UUID
from uuid import UUID, uuid4

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
Expand All @@ -35,6 +35,7 @@
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.api import ServerAPI
from chromadb.errors import (
NoRecordsError,
ChromaError,
InvalidDimensionException,
InvalidHTTPVersion,
Expand Down Expand Up @@ -744,6 +745,34 @@ async def delete_collection(
limiter=self._capacity_limiter,
)

@staticmethod
def _generate_ids_when_not_present(
ids: Optional[List[str]],
documents: Optional[List[Optional[str]]],
uris: Optional[List[Optional[str]]],
embeddings: Optional[List[Any]],
) -> List[str]:
if ids is not None and len(ids) > 0:
return ids

n = 0
if documents is not None:
n = len(documents)
elif uris is not None:
n = len(uris)
elif embeddings is not None:
n = len(embeddings)

generated_ids = []
for _ in range(n):
generated_ids.append(str(uuid4()))

return generated_ids

def _validate_add_embedding(self, add: AddEmbedding):
if len(add.ids) == 0:
raise NoRecordsError("No records to add")

@trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
async def add(
self, request: Request, collection_id: str, body: AddEmbedding = Body(...)
Expand All @@ -759,6 +788,13 @@ def process_add(request: Request, raw_body: bytes) -> bool:
None,
collection_id,
)

ids = self._generate_ids_when_not_present(
add.ids, add.documents, add.uris, add.embeddings
)
add.ids = ids
self._validate_add_embedding(add)

return self._api._add(
collection_id=_uuid(collection_id),
ids=add.ids,
Expand Down

0 comments on commit 64645ef

Please sign in to comment.