diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py
index 261d25c76b6..b890af8a262 100644
--- a/chromadb/api/__init__.py
+++ b/chromadb/api/__init__.py
@@ -24,6 +24,7 @@
     URIs,
     Where,
     QueryResult,
+    AddResult,
     GetResult,
     WhereDocument,
 )
@@ -115,13 +116,13 @@ def delete_collection(
     @abstractmethod
     def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         """[Internal] Add embeddings to a collection specified by UUID.

         If (some) ids already exist, only the new embeddings will be added.
diff --git a/chromadb/api/async_api.py b/chromadb/api/async_api.py
index d4e6a7a3a7b..050560e3dc0 100644
--- a/chromadb/api/async_api.py
+++ b/chromadb/api/async_api.py
@@ -24,6 +24,7 @@
     Where,
     QueryResult,
     GetResult,
+    AddResult,
     WhereDocument,
 )
 from chromadb.config import Component, Settings
@@ -106,13 +107,13 @@ async def delete_collection(
     @abstractmethod
     async def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         """[Internal] Add embeddings to a collection specified by UUID.

         If (some) ids already exist, only the new embeddings will be added.
diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py
index 32b0b8cf50f..c17f4f513cd 100644
--- a/chromadb/api/async_client.py
+++ b/chromadb/api/async_client.py
@@ -14,6 +14,7 @@
     EmbeddingFunction,
     Embeddings,
     GetResult,
+    AddResult,
     IDs,
     Include,
     Loadable,
@@ -256,13 +257,13 @@ async def delete_collection(
     @override
     async def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         return await self._server._add(
             ids=ids,
             collection_id=collection_id,
diff --git a/chromadb/api/async_fastapi.py b/chromadb/api/async_fastapi.py
index 489dc4d1a03..fc15c0a3552 100644
--- a/chromadb/api/async_fastapi.py
+++ b/chromadb/api/async_fastapi.py
@@ -2,7 +2,7 @@
 from uuid import UUID
 import urllib.parse
 import orjson
-from typing import Any, Optional, cast, Tuple, Sequence, Dict
+from typing import Any, Optional, cast, Sequence, Dict
 import logging
 import httpx
 from overrides import override
@@ -31,9 +31,11 @@
     Where,
     WhereDocument,
     GetResult,
+    AddResult,
     QueryResult,
     CollectionMetadata,
-    validate_batch,
+    validate_batch_size,
+    RecordSet,
     convert_np_embeddings_to_list,
 )
@@ -413,16 +415,10 @@ async def _delete(

         return cast(IDs, resp_json)

-    @trace_method("AsyncFastAPI._submit_batch", OpenTelemetryGranularity.ALL)
-    async def _submit_batch(
+    @trace_method("AsyncFastAPI._submit_record_set", OpenTelemetryGranularity.ALL)
+    async def _submit_record_set(
         self,
-        batch: Tuple[
-            IDs,
-            Optional[PyEmbeddings],
-            Optional[Metadatas],
-            Optional[Documents],
-            Optional[URIs],
-        ],
+        record_set: RecordSet,
         url: str,
     ) -> Any:
         """
@@ -432,11 +428,11 @@
             "post",
             url,
             json={
-                "ids": batch[0],
-                "embeddings": batch[1],
-                "metadatas": batch[2],
-                "documents": batch[3],
-                "uris": batch[4],
+                "ids": record_set["ids"],
+                "embeddings": record_set["embeddings"],
+                "metadatas": record_set["metadatas"],
+                "documents": record_set["documents"],
+                "uris": record_set["uris"],
             },
         )

@@ -444,23 +440,43 @@ async def _submit_batch(
     @override
     async def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings),
-            metadatas,
-            documents,
-            uris,
+    ) -> AddResult:
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings)
+            if embeddings is not None
+            else None,
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(
+            record_set, {"max_batch_size": await self.get_max_batch_size()}
+        )
+
+        resp_json = await self._make_request(
+            "post",
+            "/collections/" + str(collection_id) + "/add",
+            json={
+                "ids": record_set["ids"],
+                "embeddings": record_set["embeddings"],
+                "metadatas": record_set["metadatas"],
+                "documents": record_set["documents"],
+                "uris": record_set["uris"],
+            },
+        )
+
+        return AddResult(
+            ids=resp_json["ids"],
         )
-        validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
-        await self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
-        return True

     @trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
     @override
@@ -473,19 +489,23 @@ async def _update(
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
     ) -> bool:
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings)
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings)
             if embeddings is not None
             else None,
-            metadatas,
-            documents,
-            uris,
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(
+            record_set, {"max_batch_size": await self.get_max_batch_size()}
         )
-        validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
-        await self._submit_batch(
-            batch, "/collections/" + str(collection_id) + "/update"
+        await self._submit_record_set(
+            record_set, "/collections/" + str(collection_id) + "/update"
         )
         return True
@@ -501,17 +521,25 @@ async def _upsert(
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
     ) -> bool:
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings),
-            metadatas,
-            documents,
-            uris,
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings)
+            if embeddings is not None
+            else None,
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(
+            record_set, {"max_batch_size": await self.get_max_batch_size()}
         )
-        validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
-        await self._submit_batch(
-            batch, "/collections/" + str(collection_id) + "/upsert"
+
+        await self._submit_record_set(
+            record_set, "/collections/" + str(collection_id) + "/upsert"
         )
+
         return True

     @trace_method("AsyncFastAPI._query", OpenTelemetryGranularity.ALL)
diff --git a/chromadb/api/client.py b/chromadb/api/client.py
index c5625423868..9996ab1db8d 100644
--- a/chromadb/api/client.py
+++ b/chromadb/api/client.py
@@ -19,6 +19,7 @@
     Loadable,
     Metadatas,
     QueryResult,
+    AddResult,
     URIs,
 )
 from chromadb.config import Settings, System
@@ -208,13 +209,13 @@ def delete_collection(
     @override
     def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         return self._server._add(
             ids=ids,
             collection_id=collection_id,
diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py
index f71720a1152..7d719de1572 100644
--- a/chromadb/api/fastapi.py
+++ b/chromadb/api/fastapi.py
@@ -1,6 +1,6 @@
 import orjson
 import logging
-from typing import Any, Dict, Optional, cast, Tuple
+from typing import Any, Dict, Optional, cast
 from typing import Sequence
 from uuid import UUID
 import httpx
@@ -19,14 +19,16 @@
     IDs,
     Include,
     Metadatas,
+    RecordSet,
     URIs,
     Where,
     WhereDocument,
     GetResult,
+    AddResult,
     QueryResult,
     CollectionMetadata,
     convert_np_embeddings_to_list,
-    validate_batch,
+    validate_batch_size,
 )
 from chromadb.auth import (
     ClientAuthProvider,
@@ -378,16 +380,10 @@ def _delete(
         )
         return cast(IDs, resp_json)

-    @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
-    def _submit_batch(
+    @trace_method("FastAPI._submit_record_set", OpenTelemetryGranularity.ALL)
+    def _submit_record_set(
         self,
-        batch: Tuple[
-            IDs,
-            Optional[PyEmbeddings],
-            Optional[Metadatas],
-            Optional[Documents],
-            Optional[URIs],
-        ],
+        record_set: RecordSet,
         url: str,
     ) -> None:
         """
@@ -397,11 +393,11 @@ def _submit_batch(
             "post",
             url,
             json={
-                "ids": batch[0],
-                "embeddings": batch[1],
-                "metadatas": batch[2],
-                "documents": batch[3],
-                "uris": batch[4],
+                "ids": record_set["ids"],
+                "embeddings": record_set["embeddings"],
+                "metadatas": record_set["metadatas"],
+                "documents": record_set["documents"],
+                "uris": record_set["uris"],
             },
         )

@@ -409,27 +405,51 @@ def _submit_batch(
     @override
     def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         """
         Adds a batch of embeddings to the database
         - pass in column oriented data lists
         """
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings),
-            metadatas,
-            documents,
-            uris,
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings),
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(record_set, {"max_batch_size": self.get_max_batch_size()})
+
+        # This differs from the request for update and upsert because we want to return the ids,
+        # which are generated at the server (segment)
+        resp_json = self._make_request(
+            "post",
+            "/collections/" + str(collection_id) + "/add",
+            json={
+                "ids": record_set["ids"],
+                "embeddings": record_set["embeddings"],
+                "metadatas": record_set["metadatas"],
+                "documents": record_set["documents"],
+                "uris": record_set["uris"],
+            },
+        )
+
+        return AddResult(
+            ids=resp_json["ids"],
         )
-        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
-        self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
-        return True

     @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
     @override
@@ -446,17 +466,19 @@ def _update(
         Updates a batch of embeddings in the database
         - pass in column oriented data lists
         """
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings)
-            if embeddings is not None
-            else None,
-            metadatas,
-            documents,
-            uris,
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings)
+            if embeddings is not None
+            else None,
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(record_set, {"max_batch_size": self.get_max_batch_size()})
+        self._submit_record_set(
+            record_set, "/collections/" + str(collection_id) + "/update"
         )
-        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
-        self._submit_batch(batch, "/collections/" + str(collection_id) + "/update")
         return True

     @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
@@ -474,15 +496,20 @@ def _upsert(
         Upserts a batch of embeddings in the database
         - pass in column oriented data lists
         """
-        batch = (
-            ids,
-            convert_np_embeddings_to_list(embeddings),
-            metadatas,
-            documents,
-            uris,
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": convert_np_embeddings_to_list(embeddings)
+            if embeddings is not None
+            else None,
+            "metadatas": metadatas,
+            "documents": documents,
+            "uris": uris,
+            "images": None,
+        }
+
+        validate_batch_size(record_set, {"max_batch_size": self.get_max_batch_size()})
+
+        self._submit_record_set(
+            record_set, "/collections/" + str(collection_id) + "/upsert"
         )
-        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
-        self._submit_batch(batch, "/collections/" + str(collection_id) + "/upsert")
         return True

     @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py
index c577ce4d2b8..2ebe067e2a1 100644
--- a/chromadb/api/models/AsyncCollection.py
+++ b/chromadb/api/models/AsyncCollection.py
@@ -20,6 +20,7 @@
     Where,
     IDs,
     GetResult,
+    AddResult,
     QueryResult,
     ID,
     OneOrMany,
@@ -35,7 +36,7 @@ class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
     async def add(
         self,
-        ids: OneOrMany[ID],
+        ids: Optional[OneOrMany[ID]] = None,
         embeddings: Optional[  # type: ignore[type-arg]
             Union[
                 OneOrMany[Embedding],
@@ -46,7 +47,7 @@ async def add(
         documents: Optional[OneOrMany[Document]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
-    ) -> None:
+    ) -> AddResult:
         """Add embeddings to the data store.

         Args:
             ids: The ids of the embeddings you wish to add
@@ -76,7 +77,7 @@ async def add(
             uris=uris,
         )

-        await self._client._add(
+        result = await self._client._add(
             collection_id=self.id,
             ids=record_set["ids"],
             embeddings=cast(Embeddings, record_set["embeddings"]),
@@ -85,6 +86,8 @@ async def add(
             uris=record_set["uris"],
         )

+        return result
+
     async def count(self) -> int:
         """The total number of embeddings added to the database

@@ -272,7 +275,8 @@ async def update(
         await self._client._update(
             collection_id=self.id,
-            ids=record_set["ids"],
+            # TODO: We slightly abuse the RecordSet type here because, per the type, IDs could be None
+            ids=cast(IDs, record_set["ids"]),
             embeddings=cast(Embeddings, record_set["embeddings"]),
             metadatas=record_set["metadatas"],
             documents=record_set["documents"],
@@ -315,7 +319,8 @@ async def upsert(
         await self._client._upsert(
             collection_id=self.id,
-            ids=record_set["ids"],
+            # TODO: We slightly abuse the RecordSet type here because, per the type, IDs could be None
+            ids=cast(IDs, record_set["ids"]),
             embeddings=cast(Embeddings, record_set["embeddings"]),
             metadatas=record_set["metadatas"],
             documents=record_set["documents"],
diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py
index 042d695c42a..5e57e8a5217 100644
--- a/chromadb/api/models/Collection.py
+++ b/chromadb/api/models/Collection.py
@@ -15,6 +15,7 @@
     Where,
     IDs,
     GetResult,
+    AddResult,
     QueryResult,
     ID,
     OneOrMany,
@@ -41,7 +42,7 @@ def count(self) -> int:

     def add(
         self,
-        ids: OneOrMany[ID],
+        ids: Optional[OneOrMany[ID]] = None,
         embeddings: Optional[  # type: ignore[type-arg]
             Union[
                 OneOrMany[Embedding],
@@ -52,7 +53,7 @@ def add(
         documents: Optional[OneOrMany[Document]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
-    ) -> None:
+    ) -> AddResult:
         """Add embeddings to the data store.

         Args:
             ids: The ids of the embeddings you wish to add
@@ -82,7 +83,7 @@ def add(
             uris=uris,
         )

-        self._client._add(
+        result = self._client._add(
             collection_id=self.id,
             ids=record_set["ids"],
             embeddings=cast(Embeddings, record_set["embeddings"]),
@@ -91,6 +92,8 @@ def add(
             uris=record_set["uris"],
         )

+        return result
+
     def get(
         self,
         ids: Optional[OneOrMany[ID]] = None,
@@ -270,7 +273,8 @@ def update(
         self._client._update(
             collection_id=self.id,
-            ids=record_set["ids"],
+            # TODO: We slightly abuse the RecordSet type here because, per the type, IDs could be None
+            ids=cast(IDs, record_set["ids"]),
             embeddings=cast(Embeddings, record_set["embeddings"]),
             metadatas=record_set["metadatas"],
             documents=record_set["documents"],
@@ -313,7 +317,8 @@ def upsert(
         self._client._upsert(
             collection_id=self.id,
-            ids=record_set["ids"],
+            # TODO: We slightly abuse the RecordSet type here because, per the type, IDs could be None
+            ids=cast(IDs, record_set["ids"]),
             embeddings=cast(Embeddings, record_set["embeddings"]),
             metadatas=record_set["metadatas"],
             documents=record_set["documents"],
diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py
index d55b80d69a0..97923b7aec3 100644
--- a/chromadb/api/models/CollectionCommon.py
+++ b/chromadb/api/models/CollectionCommon.py
@@ -157,11 +157,12 @@ def _unpack_record_set(
         ] = None,
         metadatas: Optional[OneOrMany[Metadata]] = None,
         documents: Optional[OneOrMany[Document]] = None,
+        ids: Optional[OneOrMany[ID]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
     ) -> RecordSet:
         return {
-            "ids": cast(IDs, maybe_cast_one_to_many(ids)),
+            "ids": maybe_cast_one_to_many(ids),
             "embeddings": maybe_cast_one_to_many_embedding(embeddings),
             "metadatas": maybe_cast_one_to_many(metadatas),
             "documents": maybe_cast_one_to_many(documents),
@@ -185,7 +186,7 @@ def _compute_embeddings(
         else:
             if uris is None:
                 raise ValueError(
-                    "You must provide either embeddings, documents, images, or uris."
+                    "You must provide either embeddings, documents, images, or URIs."
                 )
             if self._data_loader is None:
                 raise ValueError(
@@ -354,7 +355,6 @@ def _update_model_after_modify_success(

     def _process_add_request(
         self,
-        ids: OneOrMany[ID],
         embeddings: Optional[  # type: ignore[type-arg]
             Union[
                 OneOrMany[Embedding],
@@ -363,6 +363,7 @@
         ] = None,
         metadatas: Optional[OneOrMany[Metadata]] = None,
         documents: Optional[OneOrMany[Document]] = None,
+        ids: Optional[OneOrMany[ID]] = None,
         images: Optional[OneOrMany[Image]] = None,
         uris: Optional[OneOrMany[URI]] = None,
     ) -> RecordSet:
@@ -412,6 +413,9 @@ def _process_upsert_request(
             uris=uris,
         )

+        if record_set["ids"] is None:
+            raise ValueError("You must provide ids when upserting.")
+
         validate_record_set(
             record_set,
             require_data=True,
@@ -450,6 +454,9 @@ def _process_update_request(
             uris=uris,
         )

+        if record_set["ids"] is None:
+            raise ValueError("You must provide ids when updating.")
+
         validate_record_set(
             record_set,
             require_data=False,
diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py
index 70243a80517..19b6ac7e145 100644
--- a/chromadb/api/segment.py
+++ b/chromadb/api/segment.py
@@ -37,13 +37,15 @@
     Include,
     RecordSet,
     GetResult,
+    AddResult,
     QueryResult,
     validate_metadata,
     validate_update_metadata,
     validate_where,
     validate_where_document,
-    validate_batch,
+    validate_batch_size,
     validate_record_set,
+    count_records,
 )
 from chromadb.telemetry.product.events import (
     CollectionAddEvent,
@@ -334,38 +336,46 @@ def delete_collection(
     @override
     def _add(
         self,
-        ids: IDs,
         collection_id: UUID,
         embeddings: Embeddings,
+        ids: Optional[IDs] = None,
         metadatas: Optional[Metadatas] = None,
         documents: Optional[Documents] = None,
         uris: Optional[URIs] = None,
-    ) -> bool:
+    ) -> AddResult:
         self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
         coll = self._get_collection(collection_id)
         self._manager.hint_use_collection(collection_id, t.Operation.ADD)
+        # TODO: We slightly abuse the record_set type here
+        record_set: RecordSet = {
+            "ids": ids,
+            "embeddings": embeddings,
+            "documents": documents,
+            "uris": uris,
+            "metadatas": metadatas,
+            "images": None,
+        }
+        n = count_records(record_set)
+
+        # Generate ids if not provided
+        if record_set["ids"] is None:
+            record_set["ids"] = [str(uuid4()) for _ in range(n)]
+
         self._validate_record_set(
             collection=coll,
-            record_set={
-                "ids": ids,
-                "embeddings": embeddings,
-                "documents": documents,
-                "uris": uris,
-                "metadatas": metadatas,
-                "images": None,
-            },
+            record_set=record_set,
             require_data=True,
         )

         records_to_submit = list(
             _records(
                 t.Operation.ADD,
-                ids=ids,
-                embeddings=embeddings,
-                metadatas=metadatas,
-                documents=documents,
-                uris=uris,
+                ids=cast(IDs, record_set["ids"]),  # IDs validator checks for None
+                embeddings=record_set["embeddings"],
+                metadatas=record_set["metadatas"],
+                documents=record_set["documents"],
+                uris=record_set["uris"],
             )
         )
         self._producer.submit_embeddings(collection_id, records_to_submit)
@@ -373,13 +384,15 @@ def _add(
         self._product_telemetry_client.capture(
             CollectionAddEvent(
                 collection_uuid=str(collection_id),
-                add_amount=len(ids),
-                with_metadata=len(ids) if metadatas is not None else 0,
-                with_documents=len(ids) if documents is not None else 0,
-                with_uris=len(ids) if uris is not None else 0,
+                add_amount=n,
+                with_metadata=n if metadatas is not None else 0,
+                with_documents=n if documents is not None else 0,
+                with_uris=n if uris is not None else 0,
             )
         )
-        return True
+        return AddResult(
+            ids=cast(IDs, record_set["ids"])
+        )  # IDs validator checks for None
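> Reviewer note — a minimal sketch of the caller-facing behavior these `_add` changes produce, assuming the PR lands as shown. The collection name and vectors below are illustrative: ids may be omitted, the segment generates a UUID per record, and `add` now returns an `AddResult` carrying the ids actually written.

```python
import chromadb

client = chromadb.Client()
coll = client.create_collection("demo")  # hypothetical collection name

# No ids supplied: the segment generates one UUID per record.
result = coll.add(
    embeddings=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    documents=["first", "second"],
)

# AddResult is a TypedDict with an "ids" field (see chromadb/api/types.py below),
# so callers can persist or look up the server-generated ids.
assert len(result["ids"]) == 2
```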
@trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION) @override @@ -891,14 +904,8 @@ def _validate_record_set( try: validate_record_set(record_set, require_data=require_data) - validate_batch( - ( - record_set["ids"], - record_set["embeddings"], - record_set["metadatas"], - record_set["documents"], - record_set["uris"], - ), + validate_batch_size( + record_set, {"max_batch_size": self.get_max_batch_size()}, ) diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 58104434b66..f0398d3b060 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast +from typing import Optional, Union, TypeVar, List, Dict, Any, cast from numpy.typing import NDArray import numpy as np from typing_extensions import TypedDict, Protocol, runtime_checkable @@ -163,7 +163,7 @@ class IncludeEnum(str, Enum): class RecordSet(TypedDict): - ids: IDs + ids: Optional[IDs] embeddings: Optional[Embeddings] metadatas: Optional[Metadatas] documents: Optional[Documents] @@ -227,6 +227,10 @@ class GetResult(TypedDict): included: Include +class AddResult(TypedDict): + ids: List[ID] + + class QueryResult(TypedDict): ids: List[IDs] embeddings: Optional[ @@ -308,17 +312,21 @@ def __call__(self, uris: URIs) -> L: ... -def validate_ids(ids: IDs) -> IDs: +def validate_ids(ids: Optional[IDs]) -> IDs: """Validates ids to ensure it is a list of strings""" + if ids is None: + raise ValueError("Expected IDs to be a non-empty list of str, got None") if not isinstance(ids, list): raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs") if len(ids) == 0: raise ValueError(f"Expected IDs to be a non-empty list, got {len(ids)} IDs") + seen = set() dups = set() for id_ in ids: if not isinstance(id_, str): raise ValueError(f"Expected ID to be a str, got {id_}") + if id_ in seen: dups.add(id_) else: @@ -600,26 +608,54 @@ def validate_embeddings(embeddings: Embeddings) -> Embeddings: return embeddings -def validate_batch( - batch: Tuple[ - IDs, - Optional[Union[Embeddings, PyEmbeddings]], - Optional[Metadatas], - Optional[Documents], - Optional[URIs], - ], +def validate_batch_size( + record_set: RecordSet, limits: Dict[str, Any], ) -> None: - if len(batch[0]) > limits["max_batch_size"]: + batch_size = count_records(record_set) + + if batch_size > limits["max_batch_size"]: raise ValueError( - f"Batch size {len(batch[0])} exceeds maximum batch size {limits['max_batch_size']}" + f"Batch size {batch_size} exceeds maximum batch size {limits['max_batch_size']}" + ) + + +def validate_record_set_count(record_set: RecordSet) -> None: + """ + Validate the consistency of the record set, ensuring all values are non-empty lists and have the same length. 
+ """ + error_messages = [] + counts = { + field: len( + cast(Union[IDs, Metadatas, Embeddings, Documents, Images, URIs], value) + ) + for field, value in record_set.items() + if value is not None + } + + if len(counts) == 0: + raise ValueError("Expected at least one record set field to be non-empty") + + if any(count == 0 for count in counts.values()): + error_messages.append( + f"Expected all fields to be non-empty lists, got empty lists in: {', '.join(field for field, count in counts.items() if count == 0)}" ) + if len(counts) > 1: + field_record_counts = [f"{field}: ({count})" for field, count in counts.items()] + error_messages.append( + f"Inconsistent number of records: {', '.join(field_record_counts)}" + ) + + if error_messages: + raise ValueError(", ".join(error_messages)) + def validate_record_set( record_set: RecordSet, require_data: bool, ) -> None: + validate_record_set_count(record_set) validate_ids(record_set["ids"]) validate_embeddings(record_set["embeddings"]) if record_set[ "embeddings" @@ -639,7 +675,8 @@ def validate_record_set( if not record_set_contains_one_of(record_set, include=required_fields): raise ValueError(f"You must provide one of {', '.join(required_fields)}") - valid_ids = record_set["ids"] + # The ID validator checks for None + valid_ids = cast(IDs, record_set["ids"]) for key in ["embeddings", "metadatas", "documents", "images", "uris"]: if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] raise ValueError( @@ -647,6 +684,23 @@ def validate_record_set( ) +def count_records( + record_set: RecordSet, +) -> int: + """ + Get the number of items in the record set. + """ + + validate_record_set_count(record_set) + for value in record_set.values(): + if value is not None: + return len( + cast(Union[IDs, Embeddings, Metadatas, Documents, Images, URIs], value) + ) + + raise ValueError("Expected at least one record set field to be non-empty") + + def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings: return [embedding.tolist() for embedding in embeddings] diff --git a/chromadb/errors.py b/chromadb/errors.py index 99bc1bc8970..ca8e74a932b 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -41,6 +41,17 @@ def name(cls) -> str: return "InvalidCollection" +class InvalidInputError(ChromaError): + @overrides + def code(self) -> int: + return 400 # Bad Request + + @classmethod + @overrides + def name(cls) -> str: + return "InvalidInput" + + class IDAlreadyExistsError(ChromaError): @overrides def code(self) -> int: diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 910fc1e400e..de1a370a5d5 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -31,7 +31,7 @@ QueryResult, Embeddings, convert_list_embeddings_to_np, -) +), AddResult from chromadb.auth import ( AuthzAction, AuthzResource, @@ -768,10 +768,10 @@ async def delete_collection( @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION) async def add( self, request: Request, collection_id: str, body: AddEmbedding = Body(...) 
- ) -> bool: + ) -> AddResult: try: - def process_add(request: Request, raw_body: bytes) -> bool: + def process_add(request: Request, raw_body: bytes) -> AddResult: add = validate_model(AddEmbedding, orjson.loads(raw_body)) self.auth_and_get_tenant_and_database_for_request( request.headers, @@ -795,7 +795,7 @@ def process_add(request: Request, raw_body: bytes) -> bool: ) return cast( - bool, + AddResult, await to_thread.run_sync( process_add, request, diff --git a/chromadb/server/fastapi/types.py b/chromadb/server/fastapi/types.py index f644ff7883b..7cff7dc2e54 100644 --- a/chromadb/server/fastapi/types.py +++ b/chromadb/server/fastapi/types.py @@ -16,7 +16,7 @@ class AddEmbedding(BaseModel): metadatas: Optional[List[Optional[Dict[Any, Any]]]] = None documents: Optional[List[Optional[str]]] = None uris: Optional[List[Optional[str]]] = None - ids: List[str] + ids: Optional[List[str]] = None class UpdateEmbedding(BaseModel): diff --git a/chromadb/test/api/test_api_add.py b/chromadb/test/api/test_api_add.py new file mode 100644 index 00000000000..f957c31b34c --- /dev/null +++ b/chromadb/test/api/test_api_add.py @@ -0,0 +1,109 @@ +import pytest + +from chromadb.api import ClientAPI +from chromadb.test.conftest import reset + + +def test_add_with_no_ids(client: ClientAPI) -> None: + reset(client) + + coll = client.create_collection("test") + coll.add( + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item] + documents=["a", "b", None], # type: ignore[list-item] + ) + + results = coll.get() + assert len(results["ids"]) == 3 + + coll.add( + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item] + documents=["a", "b", None], # type: ignore[list-item] + ) + + results = coll.get() + assert len(results["ids"]) == 6 + + +def test_add_with_inconsistent_number_of_items(client: ClientAPI) -> None: + reset(client) + + coll = client.create_collection("test") + + # Test case 1: Inconsistent number of ids + with pytest.raises(ValueError, match="Inconsistent number of records"): + coll.add( + ids=["1", "2"], + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, {"a": 2}, {"a": 3}], + documents=["a", "b", "c"], + ) + + # Test case 2: Inconsistent number of embeddings + with pytest.raises(ValueError, match="Inconsistent number of records"): + coll.add( + ids=["1", "2", "3"], + embeddings=[[1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, {"a": 2}, {"a": 3}], + documents=["a", "b", "c"], + ) + + # Test case 3: Inconsistent number of metadatas + with pytest.raises(ValueError, match="Inconsistent number of records"): + coll.add( + ids=["1", "2", "3"], + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, {"a": 2}], + documents=["a", "b", "c"], + ) + + # Test case 4: Inconsistent number of documents + with pytest.raises(ValueError, match="Inconsistent number of records"): + coll.add( + ids=["1", "2", "3"], + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, {"a": 2}, {"a": 3}], + documents=["a", "b"], + ) + + # Test case 5: Multiple inconsistencies + with pytest.raises(ValueError, match="Inconsistent number of records"): + coll.add( + ids=["1", "2"], + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}], + documents=["a", "b", "c", "d"], + ) + + +def 
test_add_with_partial_ids(client: ClientAPI) -> None: + reset(client) + + coll = client.create_collection("test") + + with pytest.raises(ValueError, match="Expected ID to be a str"): + coll.add( + ids=["1", None], # type: ignore[list-item] + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore[arg-type] + metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore[list-item] + documents=["a", "b", None], # type: ignore[list-item] + ) + + +def test_add_with_no_data(client: ClientAPI) -> None: + reset(client) + + coll = client.create_collection("test") + + with pytest.raises( + Exception, + match="Expected embeddings to be a list or a numpy array with at least one item", + ): + coll.add( + ids=["1"], + embeddings=[], + metadatas=[{"a": 1}], + documents=[], + ) diff --git a/chromadb/test/api/test_api_update.py b/chromadb/test/api/test_api_update.py index 32bcdb35f86..d21df85c41b 100644 --- a/chromadb/test/api/test_api_update.py +++ b/chromadb/test/api/test_api_update.py @@ -19,3 +19,11 @@ def test_update_query_with_none_data(client: ClientAPI) -> None: assert e.match( "You must provide one of embeddings, documents, images, uris, metadatas" ) + + +def test_update_with_none_ids(client: ClientAPI) -> None: + client.reset() + collection = client.create_collection("test") + with pytest.raises(ValueError) as e: + collection.update(ids=None, embeddings=[[0.1, 0.2, 0.3]]) # type: ignore[arg-type] + assert e.match("You must provide ids when updating.") diff --git a/chromadb/test/api/test_api_upsert.py b/chromadb/test/api/test_api_upsert.py new file mode 100644 index 00000000000..50596ecf0fe --- /dev/null +++ b/chromadb/test/api/test_api_upsert.py @@ -0,0 +1,10 @@ +import pytest +from chromadb.api import ClientAPI + + +def test_upsert_with_none_ids(client: ClientAPI) -> None: + client.reset() + collection = client.create_collection("test") + with pytest.raises(ValueError) as e: + collection.upsert(ids=None, embeddings=[[0.1, 0.2, 0.3]]) # type: ignore[arg-type] + assert e.match("You must provide ids when upserting.") diff --git a/chromadb/test/api/test_validations.py b/chromadb/test/api/test_validations.py index c35499864a6..2939036c681 100644 --- a/chromadb/test/api/test_validations.py +++ b/chromadb/test/api/test_validations.py @@ -1,13 +1,19 @@ import pytest import numpy as np +from typing import cast from chromadb.api.types import ( + Embeddings, + IDs, RecordSet, record_set_contains_one_of, maybe_cast_one_to_many_embedding, validate_embeddings, - Embeddings, + validate_ids, + validate_record_set_count, ) +import chromadb.errors as errors + def test_does_record_set_contain_any_data() -> None: valid_record_set: RecordSet = { @@ -29,19 +35,18 @@ def test_does_record_set_contain_any_data() -> None: "uris": None, } - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="Expected embeddings to be a non-empty list"): record_set_contains_one_of(record_set_non_list, include=["embeddings"]) # type: ignore[list-item] - assert "Expected embeddings to be a non-empty list" in str(e) - # Test case 2: Non-list field - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="Expected include to be a non-empty list"): record_set_contains_one_of(valid_record_set, include=[]) - assert "Expected include to be a non-empty list" in str(e) - # Test case 3: Non-existent field - with pytest.raises(ValueError) as e: + with pytest.raises( + ValueError, + match="Expected include key to be a a known field of RecordSet, got non_existent_field", + ): 
record_set_contains_one_of(valid_record_set, include=["non_existent_field"]) # type: ignore[list-item] assert ( @@ -103,26 +108,165 @@ def test_maybe_cast_one_to_many_embedding() -> None: def test_embeddings_validation() -> None: invalid_embeddings = [[0, 0, True], [1.2, 2.24, 3.2]] - with pytest.raises(ValueError) as e: + with pytest.raises( + ValueError, match="Expected each value in the embedding to be a int or float" + ): validate_embeddings(invalid_embeddings) # type: ignore[arg-type] - assert "Expected each value in the embedding to be a int or float" in str(e) - invalid_embeddings = [[0, 0, "invalid"], [1.2, 2.24, 3.2]] - with pytest.raises(ValueError) as e: + with pytest.raises( + ValueError, match="Expected each value in the embedding to be a int or float" + ): validate_embeddings(invalid_embeddings) # type: ignore[arg-type] - assert "Expected each value in the embedding to be a int or float" in str(e) - - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="Expected embeddings to be a list, got str"): validate_embeddings("invalid") # type: ignore[arg-type] - assert "Expected embeddings to be a list, got str" in str(e) - def test_0dim_embedding_validation() -> None: - embds: Embeddings = [[]] # type: ignore[list-item] - with pytest.raises(ValueError) as e: + embds: Embeddings = [[]] + with pytest.raises( + ValueError, + match="Expected each embedding in the embeddings to be a non-empty list", + ): validate_embeddings(embds) - assert "Expected each embedding in the embeddings to be a non-empty list" in str(e) + + +def test_ids_validation() -> None: + ids = ["id1", "id2", "id3"] + assert validate_ids(ids) == ids + + with pytest.raises(ValueError, match="Expected IDs to be a list"): + validate_ids(cast(IDs, "not a list")) + + with pytest.raises(ValueError, match="Expected IDs to be a non-empty list"): + validate_ids([]) + + with pytest.raises(ValueError, match="Expected ID to be a str"): + validate_ids(cast(IDs, ["id1", 123, "id3"])) + + with pytest.raises(errors.DuplicateIDError, match="Expected IDs to be unique"): + validate_ids(["id1", "id2", "id1"]) + + ids = [ + "id1", + "id2", + "id3", + "id4", + "id5", + "id6", + "id7", + "id8", + "id9", + "id10", + "id11", + "id12", + "id13", + "id14", + "id15", + ] * 2 + with pytest.raises(errors.DuplicateIDError, match="found 15 duplicated IDs: "): + validate_ids(ids) + + +def test_validate_record_set_consistency() -> None: + # Test record set with inconsistent lengths + inconsistent_record_set: RecordSet = { + "ids": ["1", "2"], + "embeddings": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + "metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}], + "documents": ["doc1", "doc2"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError, match="Inconsistent number of records:"): + validate_record_set_count(inconsistent_record_set) + + # Test record set with empty list + empty_list_record_set: RecordSet = { + "ids": ["1", "2", "3"], + "embeddings": [], + "metadatas": [{"key": "value1"}, {"key": "value2"}, {"key": "value3"}], + "documents": ["doc1", "doc2", "doc3"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError, match="got empty lists in: * embeddings"): + validate_record_set_count(empty_list_record_set) + + # Test record set with all None value + all_none_record_set: RecordSet = { + "ids": None, + "embeddings": None, + "metadatas": None, + "documents": None, + "images": None, + "uris": None, + } + with pytest.raises( + ValueError, match="Expected at least one record set 
field to be non-empty" + ): + validate_record_set_count(all_none_record_set) + + # Test record set with multiple errors + multiple_error_record_set: RecordSet = { + "ids": [], + "embeddings": "not a list", # type: ignore[typeddict-item] + "metadatas": [{"key": "value1"}, {"key": "value2"}], + "documents": ["doc1"], + "images": None, + "uris": None, + } + with pytest.raises(ValueError, match="got empty lists in: * ids") as exc_info: + validate_record_set_count(multiple_error_record_set) + + assert "Inconsistent number of records:" in str(exc_info.value) + + +def test_maybe_cast_one_to_many_embedding() -> None: + # Test with None input + assert maybe_cast_one_to_many_embedding(None) is None + + # Test with a single embedding as a list + single_embedding = [1.0, 2.0, 3.0] + result = maybe_cast_one_to_many_embedding(single_embedding) + assert result == [single_embedding] + + # Test with multiple embeddings as a list of lists + multiple_embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = maybe_cast_one_to_many_embedding(multiple_embeddings) # type: ignore[arg-type] + assert result == multiple_embeddings + + # Test with a numpy array (single embedding) + np_single = np.array([1.0, 2.0, 3.0]) + result = maybe_cast_one_to_many_embedding(np_single) + assert isinstance(result, list) + assert len(result) == 1 + assert np.array_equal(result[0], np_single) + + # Test with a numpy array (multiple embeddings) + np_multiple = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + result = maybe_cast_one_to_many_embedding(np_multiple) + assert isinstance(result, list) + assert len(result) == 2 + assert np.array_equal(result, np_multiple) + + # Test with an empty list (should raise ValueError) + with pytest.raises( + ValueError, match="Expected embeddings to be a list with at least one item" + ): + maybe_cast_one_to_many_embedding([]) + + # Test with an empty list (should raise ValueError) + with pytest.raises( + ValueError, match="Expected embeddings to be a list with at least one item" + ): + maybe_cast_one_to_many_embedding(np.array([])) + + # Test with an empty str (should raise ValueError) + with pytest.raises( + ValueError, + match="Expected embeddings to be a list or a numpy array, got str", + ): + maybe_cast_one_to_many_embedding("") # type: ignore[arg-type] diff --git a/chromadb/test/distributed/test_sanity.py b/chromadb/test/distributed/test_sanity.py index 34f04759623..3e5b042f68e 100644 --- a/chromadb/test/distributed/test_sanity.py +++ b/chromadb/test/distributed/test_sanity.py @@ -49,7 +49,8 @@ def test_add( "metadatas": None, "documents": None, }, - 10, + n_records=len(ids), + n_results=10, query_embeddings=[random_query], ) @@ -93,6 +94,7 @@ def test_add_include_all_with_compaction_delay(client: ClientAPI) -> None: "metadatas": None, "documents": documents, }, - 10, + n_records=len(ids), + n_results=10, query_embeddings=[random_query_1, random_query_2], ) diff --git a/chromadb/test/property/invariants.py b/chromadb/test/property/invariants.py index 25fdc34f034..2e07451e162 100644 --- a/chromadb/test/property/invariants.py +++ b/chromadb/test/property/invariants.py @@ -5,7 +5,11 @@ from chromadb.db.impl.sqlite import SqliteDB from time import sleep import psutil -from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet +from chromadb.test.property.strategies import ( + NormalizedRecordSet, + RecordSet, + StateMachineRecordSet, +) from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast, Any from typing_extensions import Literal import numpy as np @@ -67,7 
+71,7 @@ def wrap_all(record_set: RecordSet) -> NormalizedRecordSet: ) return { - "ids": wrap(record_set["ids"]), + "ids": wrap(record_set["ids"]) if record_set["ids"] is not None else None, "documents": wrap(record_set["documents"]) if record_set["documents"] is not None else None, @@ -78,11 +82,57 @@ def wrap_all(record_set: RecordSet) -> NormalizedRecordSet: } +def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int: + # we need to replace empty lists with None within the record set state to use get_n_items_from_record_set + # get_n_items_from_record_set would throw an error if it encounters an empty list + if all(len(value) == 0 for value in state_record_set.values()): # type: ignore[arg-type] + return 0 + + record_set_with_empty_lists_replaced: types.RecordSet = { + "ids": None, + "documents": None, + "metadatas": None, + "embeddings": None, + "images": None, + "uris": None, + } + + for key, value in state_record_set.items(): + record_set_with_empty_lists_replaced[key] = None if len(value) == 0 else value # type: ignore[literal-required, arg-type] + + return types.count_records(record_set_with_empty_lists_replaced) + + +def get_n_items_from_record_set(record_set: RecordSet) -> int: + """Get the number of items from a record set""" + normalized_record_set = wrap_all(record_set) + + return types.count_records( + { + "ids": normalized_record_set["ids"], + "embeddings": normalized_record_set["embeddings"], + "metadatas": cast(types.Metadatas, normalized_record_set["metadatas"]), + "documents": normalized_record_set["documents"], + "uris": None, + "images": None, + } + ) + + def count(collection: Collection, record_set: RecordSet) -> None: """The given collection count is equal to the number of embeddings""" count = collection.count() - normalized_record_set = wrap_all(record_set) - assert count == len(normalized_record_set["ids"]) + n = get_n_items_from_record_set(record_set) + assert count == n + + +def count_state_record_set( + collection: Collection, record_set: StateMachineRecordSet +) -> None: + """The given collection count is equal to the number of embeddings within the state record set""" + count = collection.count() + n = get_n_items_from_record_set_state(record_set) + assert count == n def _field_matches( @@ -91,22 +141,25 @@ def _field_matches( field_name: Union[ Literal["documents"], Literal["metadatas"], Literal["embeddings"] ], + n: int, ) -> None: """ - The actual embedding field is equal to the expected field + The actual record field is equal to the expected field field_name: one of [documents, metadatas] """ + # If there are no ids, then there are no data to test + if normalized_record_set["ids"] is None: + raise ValueError("IDs should not be None") + result = collection.get(ids=normalized_record_set["ids"], include=[field_name]) # type: ignore[list-item] + # The test_out_of_order_ids test fails because of this in test_add.py # Here we sort by the ids to match the input order embedding_id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])} actual_field = result[field_name] - if len(normalized_record_set["ids"]) == 0: - if field_name == "embeddings": - assert cast(npt.NDArray[Any], actual_field).size == 0 - else: - assert actual_field == [] + if n == 0: + assert actual_field == [] return # This assert should never happen, if we include metadatas/documents it will be @@ -137,27 +190,90 @@ def ids_match(collection: Collection, record_set: RecordSet) -> None: actual_ids = collection.get(ids=normalized_record_set["ids"], 
include=[])["ids"] # The test_out_of_order_ids test fails because of this in test_add.py # Here we sort the ids to match the input order + + if normalized_record_set["ids"] is None: + raise ValueError("IDs should not be None") + embedding_id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])} actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id]) + assert actual_ids == normalized_record_set["ids"] def metadatas_match(collection: Collection, record_set: RecordSet) -> None: - """The actual embedding metadata is equal to the expected metadata""" + """The actual record set metadatas are equal to the expected metadatas""" normalized_record_set = wrap_all(record_set) - _field_matches(collection, normalized_record_set, "metadatas") + + _field_matches( + collection, + normalized_record_set, + "metadatas", + get_n_items_from_record_set(record_set), + ) + + +def metadatas_match_state_record_set( + collection: Collection, record_set: StateMachineRecordSet +) -> None: + """The actual metadatas within the state record set are equal to the expected metadata""" + normalized_record_set = wrap_all(cast(RecordSet, record_set)) + + _field_matches( + collection, + normalized_record_set, + "metadatas", + get_n_items_from_record_set_state(record_set), + ) def documents_match(collection: Collection, record_set: RecordSet) -> None: - """The actual embedding documents is equal to the expected documents""" + """The actual record set documents are equal to the expected documents""" normalized_record_set = wrap_all(record_set) - _field_matches(collection, normalized_record_set, "documents") + _field_matches( + collection, + normalized_record_set, + "documents", + get_n_items_from_record_set(record_set), + ) + + +def documents_match_state_record_set( + collection: Collection, record_set: StateMachineRecordSet +) -> None: + """The actual documents within the state record set are equal to the expected documents""" + normalized_record_set = wrap_all(cast(RecordSet, record_set)) + + _field_matches( + collection, + normalized_record_set, + "documents", + get_n_items_from_record_set_state(record_set), + ) def embeddings_match(collection: Collection, record_set: RecordSet) -> None: - """The actual embedding documents is equal to the expected documents""" + """The actual record set embeddings are equal to the expected embeddings""" normalized_record_set = wrap_all(record_set) - _field_matches(collection, normalized_record_set, "embeddings") + _field_matches( + collection, + normalized_record_set, + "embeddings", + get_n_items_from_record_set(record_set), + ) + + +def embeddings_match_state_record_set( + collection: Collection, record_set: StateMachineRecordSet +) -> None: + """The actual embeddings within the state record set are equal to the expected embeddings""" + normalized_record_set = wrap_all(cast(RecordSet, record_set)) + + _field_matches( + collection, + normalized_record_set, + "embeddings", + get_n_items_from_record_set_state(record_set), + ) def no_duplicates(collection: Collection) -> None: @@ -212,6 +328,7 @@ def fd_not_exceeding_threadpool_size(threadpool_size: int) -> None: def ann_accuracy( collection: Collection, record_set: RecordSet, + n_records: int, n_results: int = 1, min_recall: float = 0.99, embedding_function: Optional[types.EmbeddingFunction] = None, # type: ignore[type-arg] @@ -221,9 +338,12 @@ def ann_accuracy( """Validate that the API performs nearest_neighbor searches correctly""" normalized_record_set = wrap_all(record_set) - if 
len(normalized_record_set["ids"]) == 0: + if n_records == 0: return # nothing to test here + if normalized_record_set["ids"] is None: + raise ValueError("IDs should not be None") + embeddings: Optional[types.Embeddings] = normalized_record_set["embeddings"] have_embeddings = embeddings is not None and len(embeddings) > 0 if not have_embeddings: diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 34a2b6b4067..9ee0dd50444 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -1,7 +1,7 @@ import hashlib import hypothesis import hypothesis.strategies as st -from typing import Any, Optional, List, Dict, Union, cast +from typing import Any, Optional, List, Dict, Union, cast, Tuple from typing_extensions import TypedDict import uuid import numpy as np @@ -17,6 +17,7 @@ EmbeddingFunction, Embeddings, Metadata, + IDs, ) from chromadb.types import LiteralValue, WhereOperator, LogicalOperator @@ -56,7 +57,7 @@ class RecordSet(TypedDict): represent what a user would pass to the API. """ - ids: Union[types.ID, List[types.ID]] + ids: Optional[Union[types.ID, List[types.ID]]] embeddings: Optional[Union[types.Embeddings, types.Embedding]] metadatas: Optional[Union[List[Optional[types.Metadata]], types.Metadata]] documents: Optional[Union[List[types.Document], types.Document]] @@ -67,7 +68,7 @@ class NormalizedRecordSet(TypedDict): A RecordSet, with all fields normalized to lists. """ - ids: List[types.ID] + ids: Optional[List[types.ID]] embeddings: Optional[types.Embeddings] metadatas: Optional[List[Optional[types.Metadata]]] documents: Optional[List[types.Document]] @@ -194,7 +195,7 @@ def create_embeddings_ndarray( dim: int, count: int, dtype: npt.DTypeLike, -) -> np.typing.NDArray[Any]: +) -> np.ndarray: # type: ignore[type-arg] return np.random.uniform( low=-1.0, high=1.0, @@ -411,11 +412,12 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document: # For cluster tests, we want to avoid generating documents of length < 3. # We also don't want them to contain certan special # characters like _ and % that implicitly involve searching for a regex in sqlite. + if not NOT_CLUSTER_ONLY: # Blacklist certain unicode characters that affect sqlite processing. # For example, the null (/x00) character makes sqlite stop processing a string. # Also, blacklist _ and % for cluster tests. - blacklist_categories = ("Cc", "Cs", "Pc", "Po") + blacklist_categories: Tuple[str, ...] = ("Cc", "Cs", "Pc", "Po") if collection.known_document_keywords: known_words_st = st.sampled_from(collection.known_document_keywords) else: @@ -461,17 +463,28 @@ def recordsets( num_unique_metadata: Optional[int] = None, min_metadata_size: int = 0, max_metadata_size: Optional[int] = None, + can_ids_be_empty: bool = False, ) -> RecordSet: collection = draw(collection_strategy) - ids = list( + # Generating an integer for n_records first, then a coinflip, then n_records IDs if the coinflip is heads + # creates a combinatorial explosion because we need to sample across all n_records and all possible IDs for n_records + # generating IDs and stomping them 50% of the avoids that. 
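> Reviewer note — as a standalone gloss on the comment above (the PR's `recordsets` strategy continues right below), here is a hypothetical Hypothesis strategy showing the same "generate, then stomp" pattern. The `maybe_ids` name and signature are illustrative only, not part of this PR:

```python
from typing import List, Optional, Tuple

import hypothesis.strategies as st


@st.composite
def maybe_ids(draw: st.DrawFn, max_size: int = 10) -> Tuple[Optional[List[str]], int]:
    # Always draw the id list first, so the record count is pinned down...
    ids = draw(st.lists(st.uuids().map(str), max_size=max_size, unique=True))
    n_records = len(ids)
    # ...then stomp the ids to None on a coinflip, instead of branching before
    # generation and multiplying the search space.
    if draw(st.booleans()):
        return None, n_records
    return ids, n_records
```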
+ ids: Optional[List[types.ID]] = list( draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True)) ) + n_records = len(cast(IDs, ids)) + + if can_ids_be_empty and draw(st.booleans()): + ids = None + embeddings: Optional[Embeddings] = None if collection.has_embeddings: - embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype) - num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids) + embeddings = create_embeddings( + collection.dimension, n_records, collection.dtype + ) + num_metadata = num_unique_metadata if num_unique_metadata is not None else n_records generated_metadatas = draw( st.lists( metadata( @@ -482,20 +495,24 @@ def recordsets( ) ) metadatas = [] - for i in range(len(ids)): + for i in range(n_records): metadatas.append(generated_metadatas[i % len(generated_metadatas)]) documents: Optional[Documents] = None if collection.has_documents: documents = draw( - st.lists(document(collection), min_size=len(ids), max_size=len(ids)) + st.lists(document(collection), min_size=n_records, max_size=n_records) ) # in the case where we have a single record, sometimes exercise # the code that handles individual values rather than lists. # In this case, any field may be a list or a single value. - if len(ids) == 1: - single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids + if n_records == 1: + single_id: Optional[Union[str, List[str]]] = ( + ids[0] + if (ids is not None and len(ids) == 1 and draw(st.booleans())) + else ids + ) single_embedding = ( embeddings[0] if embeddings is not None and draw(st.booleans()) @@ -557,7 +574,7 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: if not NOT_CLUSTER_ONLY: legal_ops = [None, "$eq"] else: - legal_ops = [None, "$eq", "$ne", "$in", "$nin"] + legal_ops: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"] # type: ignore[no-redef] if not isinstance(value, str) and not isinstance(value, bool): legal_ops.extend(["$gt", "$lt", "$lte", "$gte"]) @@ -608,10 +625,10 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu else: op = draw(st.sampled_from(["$contains", "$not_contains"])) - if op == "$contains": + if op == "$contains": # type: ignore[comparison-overlap] return {"$contains": word} else: - assert op == "$not_contains" + assert op == "$not_contains" # type: ignore[comparison-overlap] return {"$not_contains": word} @@ -678,7 +695,10 @@ def filters( st.one_of(st.none(), recursive_where_doc_clause(collection)) ) - ids: Optional[Union[List[types.ID], types.ID]] + if recordset["ids"] is None: + raise ValueError("Record set IDs cannot be None") + + ids: Union[List[types.ID], types.ID] # Record sets can be a value instead of a list of values if there is only one record if isinstance(recordset["ids"], str): ids = [recordset["ids"]] @@ -686,7 +706,7 @@ def filters( ids = recordset["ids"] if not include_all_ids: - ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) + ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids)))) # type: ignore[assignment] if ids is not None: # Remove duplicates since hypothesis samples with replacement ids = list(set(ids)) diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index 0284cfe7b3c..547136d74f6 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -26,7 +26,9 @@ # record sets so we explicitly create a large record set without using Hypothesis @given( collection=collection_st, - 
record_set=strategies.recordsets(collection_st, min_size=1, max_size=500), + record_set=strategies.recordsets( + collection_st, max_size=500, can_ids_be_empty=True + ), should_compact=st.booleans(), ) @settings( @@ -54,6 +56,7 @@ def test_add_small( num_unique_metadata=5, min_metadata_size=1, max_metadata_size=5, + can_ids_be_empty=True, ), should_compact=st.booleans(), ) @@ -104,43 +107,37 @@ def _test_add( # TODO: The type of add() is incorrect as it does not allow for metadatas # like [{"a": 1}, None, {"a": 3}] - for batch in create_batches( - api=client, - ids=cast(List[str], record_set["ids"]), - embeddings=cast(Embeddings, record_set["embeddings"]), - metadatas=cast(Metadatas, record_set["metadatas"]), - documents=cast(List[str], record_set["documents"]), - ): - coll.add(*batch) + result = coll.add(**record_set) # type: ignore[arg-type] + if normalized_record_set["ids"] is None: + normalized_record_set["ids"] = result["ids"] + + n_records = invariants.get_n_items_from_record_set(record_set) + # Only wait for compaction if the size of the collection is # some minimal size - if ( - not NOT_CLUSTER_ONLY - and should_compact - and len(normalized_record_set["ids"]) > 10 - ): + if not NOT_CLUSTER_ONLY and should_compact and n_records > 10: # Wait for the model to be updated wait_for_version_increase(client, collection.name, initial_version) invariants.count(coll, cast(strategies.RecordSet, normalized_record_set)) - n_results = max(1, (len(normalized_record_set["ids"]) // 10)) + n_results = max(1, (n_records // 10)) if batch_ann_accuracy: batch_size = 10 - for i in range(0, len(normalized_record_set["ids"]), batch_size): + for i in range(0, n_records, batch_size): invariants.ann_accuracy( coll, cast(strategies.RecordSet, normalized_record_set), + n_records=n_records, n_results=n_results, embedding_function=collection.embedding_function, - query_indices=list( - range(i, min(i + batch_size, len(normalized_record_set["ids"]))) - ), + query_indices=list(range(i, min(i + batch_size, n_records))), ) else: invariants.ann_accuracy( coll, cast(strategies.RecordSet, normalized_record_set), + n_records=n_records, n_results=n_results, embedding_function=collection.embedding_function, ) @@ -193,13 +190,12 @@ def test_add_large( metadatas=cast(Metadatas, record_set["metadatas"]), documents=cast(List[str], record_set["documents"]), ): - coll.add(*batch) + results = coll.add(*batch) + if results["ids"] is None: + raise ValueError("IDs should not be None") - if ( - not NOT_CLUSTER_ONLY - and should_compact - and len(normalized_record_set["ids"]) > 10 - ): + n_records = invariants.get_n_items_from_record_set(record_set) + if not NOT_CLUSTER_ONLY and should_compact and n_records > 10: # Wait for the model to be updated, since the record set is larger, add some additional time wait_for_version_increase( client, collection.name, initial_version, additional_time=240 diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 4d679f0a424..68d9631fba4 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -98,7 +98,7 @@ def _patch_telemetry_client( def patch_for_version( version: str, collection: strategies.Collection, - embeddings: strategies.RecordSet, + record_set: strategies.RecordSet, settings: Settings, ) -> None: """Override aspects of the collection and embeddings, before testing, to account for @@ -108,7 +108,7 @@ def patch_for_version( if 
diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py
index 4d679f0a424..68d9631fba4 100644
--- a/chromadb/test/property/test_cross_version_persist.py
+++ b/chromadb/test/property/test_cross_version_persist.py
@@ -98,7 +98,7 @@ def _patch_telemetry_client(
 def patch_for_version(
     version: str,
     collection: strategies.Collection,
-    embeddings: strategies.RecordSet,
+    record_set: strategies.RecordSet,
     settings: Settings,
 ) -> None:
     """Override aspects of the collection and embeddings, before testing, to account for
@@ -108,7 +108,7 @@ def patch_for_version(
         if packaging_version.Version(version) <= packaging_version.Version(
             patch_version
         ):
-            patch(collection, embeddings, settings)
+            patch(collection, record_set, settings)

 def api_import_for_version(module: Any, version: str) -> Type:  # type: ignore
@@ -235,7 +235,7 @@ def persist_generated_data_with_old_version(
     version: str,
     settings: Settings,
     collection_strategy: strategies.Collection,
-    embeddings_strategy: strategies.RecordSet,
+    record_set: strategies.RecordSet,
     conn: Connection,
 ) -> None:
     try:
@@ -257,18 +257,31 @@ def persist_generated_data_with_old_version(
         # In order to test old versions, we can't rely on the not_implemented function
         embedding_function=not_implemented_ef(),
     )
-    coll.add(**embeddings_strategy)
+    result = coll.add(**record_set)
+
+    if (
+        packaging_version.Version(version) >= packaging_version.Version("0.5.5")
+        and record_set["ids"] is None
+    ):
+        if result is None:
+            raise ValueError("IDs from embeddings strategy should not be None")
+
+        if result["ids"] is None:
+            raise ValueError("IDs from result should not be None")
+
+        record_set["ids"] = result["ids"]

     # Just use some basic checks for sanity and manual testing where you break the new
     # version
-    check_embeddings = invariants.wrap_all(embeddings_strategy)
+    check_embeddings = invariants.wrap_all(record_set)
     # Check count
-    assert coll.count() == len(check_embeddings["embeddings"] or [])
+    assert coll.count() == len(check_embeddings["embeddings"])  # type: ignore[arg-type]
+
     # Check ids
     result = coll.get()
     actual_ids = result["ids"]
-    embedding_id_to_index = {id: i for i, id in enumerate(check_embeddings["ids"])}
+    embedding_id_to_index = {id: i for i, id in enumerate(check_embeddings["ids"])}  # type: ignore[arg-type]
     actual_ids = sorted(actual_ids, key=lambda id: embedding_id_to_index[id])
     assert actual_ids == check_embeddings["ids"]

@@ -301,32 +314,43 @@ def persist_generated_data_with_old_version(
 @given(
     collection_strategy=collection_st,
-    embeddings_strategy=strategies.recordsets(collection_st),
+    record_set=strategies.recordsets(collection_strategy=collection_st),
+    should_stomp_ids=st.booleans(),
 )
 @settings(deadline=None)
 def test_cycle_versions(
     version_settings: Tuple[str, Settings],
     collection_strategy: strategies.Collection,
-    embeddings_strategy: strategies.RecordSet,
+    record_set: strategies.RecordSet,
+    should_stomp_ids: bool,
 ) -> None:
     # Test backwards compatibility
     # For the current version, ensure that we can load a collection from
     # the previous versions
     version, settings = version_settings
+
+    # TODO: This condition is subject to change as we decide on whether we want to
+    # release auto ID generation feature after 0.5.5
+
+    if (
+        packaging_version.Version(version) > packaging_version.Version("0.5.5")
+        and should_stomp_ids
+    ):
+        record_set["ids"] = None
+
     # The strategies can generate metadatas of malformed inputs. Other tests
     # will error check and cover these cases to make sure they error. Here we
     # just convert them to valid values since the error cases are already tested
-    if embeddings_strategy["metadatas"] == {}:
-        embeddings_strategy["metadatas"] = None
-    if embeddings_strategy["metadatas"] is not None and isinstance(
-        embeddings_strategy["metadatas"], list
+    if record_set["metadatas"] == {}:
+        record_set["metadatas"] = None
+    if record_set["metadatas"] is not None and isinstance(
+        record_set["metadatas"], list
     ):
-        embeddings_strategy["metadatas"] = [
-            m if m is None or len(m) > 0 else None
-            for m in embeddings_strategy["metadatas"]
+        record_set["metadatas"] = [
+            m if m is None or len(m) > 0 else None for m in record_set["metadatas"]
         ]

-    patch_for_version(version, collection_strategy, embeddings_strategy, settings)
+    patch_for_version(version, collection_strategy, record_set, settings)

     # Can't pickle a function, and we won't need them
     collection_strategy.embedding_function = None
@@ -339,7 +363,7 @@ def test_cycle_versions(
     conn1, conn2 = multiprocessing.Pipe()
     p = ctx.Process(
         target=persist_generated_data_with_old_version,
-        args=(version, settings, collection_strategy, embeddings_strategy, conn2),
+        args=(version, settings, collection_strategy, record_set, conn2),
     )
     p.start()
     p.join()
@@ -390,13 +414,19 @@ def test_cycle_versions(
     invariants.log_size_below_max(system, [coll], True)

     # Should be able to add embeddings
-    coll.add(**embeddings_strategy)  # type: ignore
-
-    invariants.count(coll, embeddings_strategy)
-    invariants.metadatas_match(coll, embeddings_strategy)
-    invariants.documents_match(coll, embeddings_strategy)
-    invariants.ids_match(coll, embeddings_strategy)
-    invariants.ann_accuracy(coll, embeddings_strategy)
+    result = coll.add(**record_set)  # type: ignore[arg-type]
+    if record_set["ids"] is None:
+        record_set["ids"] = result["ids"]
+
+    invariants.count(coll, record_set)
+    invariants.metadatas_match(coll, record_set)
+    invariants.documents_match(coll, record_set)
+    invariants.ids_match(coll, record_set)
+    invariants.ann_accuracy(
+        coll,
+        record_set,
+        n_records=invariants.get_n_items_from_record_set(record_set),
+    )

     invariants.log_size_below_max(system, [coll], True)

     # Shutdown system
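The 0.5.5 boundary above gates both the id stomping and the backfill from the result. A small sketch of that gate using the packaging library this test module already imports (the helper name is illustrative):

from packaging import version as packaging_version

def supports_auto_ids(chroma_version: str) -> bool:
    # Only releases after 0.5.5 are assumed to accept record sets without
    # caller-supplied ids and to return generated ids from add().
    return packaging_version.Version(chroma_version) > packaging_version.Version("0.5.5")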
[normalized_record_set["ids"].index(id) for id in new_ids] + new_ids = list(set(ids).difference(intersection)) # type: ignore[arg-type] + indices = [ids.index(id) for id in new_ids] # type: ignore[union-attr] filtered_record_set: strategies.NormalizedRecordSet = { - "ids": [normalized_record_set["ids"][i] for i in indices], + "ids": [ids[i] for i in indices], # type: ignore[index] "metadatas": [normalized_record_set["metadatas"][i] for i in indices] if normalized_record_set["metadatas"] else None, @@ -144,19 +148,30 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID else None, } self.collection.add(**normalized_record_set) # type: ignore[arg-type] + self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set)) - return multiple(*filtered_record_set["ids"]) + return multiple(*filtered_record_set["ids"]) # type: ignore[misc] + # if there is no intersection, we can apply the entire record set to the state else: - self.collection.add(**normalized_record_set) # type: ignore[arg-type] + result = self.collection.add(**normalized_record_set) # type: ignore[arg-type] + + if normalized_record_set["ids"] is None: + normalized_record_set["ids"] = result["ids"] + self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set)) - return multiple(*normalized_record_set["ids"]) + return multiple(*normalized_record_set["ids"]) # type: ignore[misc] @rule(ids=st.lists(consumes(embedding_ids), min_size=1)) def delete_by_ids(self, ids: IDs) -> None: trace("remove embeddings") self.on_state_change(EmbeddingStateMachineStates.delete_by_ids) - indices_to_remove = [self.record_set_state["ids"].index(id) for id in ids] + + state_ids = self.record_set_state["ids"] + if state_ids is None: + raise ValueError("IDs within the record set state should not be None") + + indices_to_remove = [state_ids.index(id) for id in ids] self.collection.delete(ids=ids) self._remove_embeddings(set(indices_to_remove)) @@ -198,8 +213,9 @@ def upsert_embeddings(self, record_set: strategies.RecordSet) -> None: @invariant() def count(self) -> None: - invariants.count( - self.collection, cast(strategies.RecordSet, self.record_set_state) + invariants.count_state_record_set( + self.collection, + self.record_set_state, ) @invariant() @@ -209,6 +225,9 @@ def no_duplicates(self) -> None: @invariant() def ann_accuracy(self) -> None: invariants.ann_accuracy( + n_records=invariants.get_n_items_from_record_set_state( + self.record_set_state + ), collection=self.collection, record_set=cast(strategies.RecordSet, self.record_set_state), min_recall=0.95, @@ -217,10 +236,15 @@ def ann_accuracy(self) -> None: @invariant() def fields_match(self) -> None: - self.record_set_state = cast(strategies.RecordSet, self.record_set_state) # type: ignore[assignment] - invariants.embeddings_match(self.collection, self.record_set_state) # type: ignore[arg-type] - invariants.metadatas_match(self.collection, self.record_set_state) # type: ignore[arg-type] - invariants.documents_match(self.collection, self.record_set_state) # type: ignore[arg-type] + invariants.embeddings_match_state_record_set( + self.collection, self.record_set_state + ) + invariants.metadatas_match_state_record_set( + self.collection, self.record_set_state + ) + invariants.documents_match_state_record_set( + self.collection, self.record_set_state + ) @precondition( lambda self: is_client_in_process(self.client) @@ -236,6 +260,10 @@ def _upsert_embeddings(self, record_set: strategies.RecordSet) -> None: normalized_record_set: 
+
+        if normalized_record_set["ids"] is None:
+            raise ValueError("IDs should not be empty")
+
         for idx, id in enumerate(normalized_record_set["ids"]):
             # Update path
             if id in self.record_set_state["ids"]:
@@ -373,23 +401,27 @@ def wait_for_compaction(self) -> None:

     @rule(
         target=embedding_ids,
-        record_set=strategies.recordsets(collection_st),
+        record_set=strategies.recordsets(collection_st, can_ids_be_empty=True),
     )
     def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]:
         res = super().add_embeddings(record_set)
-        normalized_record_set: strategies.NormalizedRecordSet = invariants.wrap_all(
-            record_set
-        )
+
+        n_records = invariants.get_n_items_from_record_set(record_set)
+        ids = [id for id in res]
+
         print(
             "[test_embeddings][add] Non Intersection ids ",
-            normalized_record_set["ids"],
+            ids,
             " len ",
-            len(normalized_record_set["ids"]),
+            len(ids),
         )
-        self.log_operation_count += len(normalized_record_set["ids"])
-        for id in normalized_record_set["ids"]:
+
+        self.log_operation_count += n_records
+
+        for id in res:
             if id not in self.unique_ids_in_log:
                 self.unique_ids_in_log.add(id)
+
         return res  # type: ignore[return-value]

     @rule(ids=st.lists(consumes(embedding_ids), min_size=1))
@@ -414,13 +446,18 @@ def delete_by_ids(self, ids: IDs) -> None:
     )
     def update_embeddings(self, record_set: strategies.RecordSet) -> None:
         super().update_embeddings(record_set)
+
+        if record_set["ids"] is None:
+            raise ValueError("IDs should not be empty")
+
+        n = len(invariants.wrap(record_set["ids"]))
+
         print(
             "[test_embeddings][update] ids ",
             record_set["ids"],
             " len ",
-            len(invariants.wrap(record_set["ids"])),
+            n,
         )
-        self.log_operation_count += len(invariants.wrap(record_set["ids"]))
+        self.log_operation_count += n

     # Using a value < 3 causes more retries and lowers the number of valid samples
     @precondition(lambda self: len(self.record_set_state["ids"]) >= 3)
@@ -434,14 +471,21 @@ def update_embeddings(self, record_set: strategies.RecordSet) -> None:
     )
     def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:
         super().upsert_embeddings(record_set)
+
+        if record_set["ids"] is None:
+            raise ValueError("IDs should not be empty")
+
+        ids = invariants.wrap(record_set["ids"])
+        n_ids = len(ids)
+
         print(
             "[test_embeddings][upsert] ids ",
             record_set["ids"],
             " len ",
-            len(invariants.wrap(record_set["ids"])),
+            n_ids,
         )
-        self.log_operation_count += len(invariants.wrap(record_set["ids"]))
-        for id in invariants.wrap(record_set["ids"]):
+        self.log_operation_count += n_ids
+        for id in ids:
             if id not in self.unique_ids_in_log:
                 self.unique_ids_in_log.add(id)
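The state machine above mirrors Chroma's add semantics: re-adding an existing id is a no-op, so only non-duplicate records are applied to the model state. A toy reduction of that filtering step, with illustrative values:

existing_state_ids = {"id-1", "id-2"}
incoming_ids = ["id-2", "id-3", "id-4"]

# Records whose ids already exist in the state are dropped, matching
# the filtered_record_set construction in add_embeddings above.
new_ids = [i for i in incoming_ids if i not in existing_state_ids]
assert new_ids == ["id-3", "id-4"]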
diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py
index 4ba614ad705..5b3b2cff365 100644
--- a/chromadb/test/property/test_filtering.py
+++ b/chromadb/test/property/test_filtering.py
@@ -123,7 +123,7 @@ def _filter_embedding_set(
     """Return IDs from the embedding set that match the given filter object"""
     normalized_record_set = invariants.wrap_all(record_set)

-    ids = set(normalized_record_set["ids"])
+    ids = set(normalized_record_set["ids"])  # type: ignore[arg-type]

     filter_ids = filter["ids"]

@@ -134,23 +134,23 @@ def _filter_embedding_set(
         if len(filter_ids) != 0:
             ids = ids.intersection(filter_ids)

-    for i in range(len(normalized_record_set["ids"])):
+    for i in range(len(normalized_record_set["ids"])):  # type: ignore[arg-type]
         if filter["where"]:
             metadatas: Metadatas
             if isinstance(normalized_record_set["metadatas"], list):
                 metadatas = normalized_record_set["metadatas"]  # type: ignore[assignment]
             else:
-                metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"])
+                metadatas = [EMPTY_DICT] * len(normalized_record_set["ids"])  # type: ignore[arg-type]
             filter_where: Where = filter["where"]
             if not _filter_where_clause(filter_where, metadatas[i]):
-                ids.discard(normalized_record_set["ids"][i])
+                ids.discard(normalized_record_set["ids"][i])  # type: ignore[index]

         if filter["where_document"]:
             documents = normalized_record_set["documents"] or [EMPTY_STRING] * len(
                 normalized_record_set["ids"]  # type: ignore[arg-type]
             )
             if not _filter_where_doc_clause(filter["where_document"], documents[i]):
-                ids.discard(normalized_record_set["ids"][i])
+                ids.discard(normalized_record_set["ids"][i])  # type: ignore[index]

     return list(ids)

@@ -160,7 +160,8 @@ def _filter_embedding_set(
     key="coll",
 )
 recordset_st = st.shared(
-    strategies.recordsets(collection_st, max_size=1000), key="recordset"
+    strategies.recordsets(collection_st, max_size=1000),
+    key="recordset",
 )

@@ -195,7 +196,10 @@ def test_filterable_metadata_get(
         embedding_function=collection.embedding_function,
     )

-    initial_version = coll.get_model()["version"]
+    initial_version = cast(int, coll.get_model()["version"])
+
+    if record_set["ids"] is None:
+        raise ValueError("Record set IDs cannot be None")

     coll.add(**record_set)

@@ -308,7 +312,7 @@ def test_filterable_metadata_query(
         metadata=collection.metadata,  # type: ignore
         embedding_function=collection.embedding_function,
     )
-    initial_version = coll.get_model()["version"]
+    initial_version = cast(int, coll.get_model()["version"])
     normalized_record_set = invariants.wrap_all(record_set)
     coll.add(**record_set)  # type: ignore[arg-type]

@@ -316,11 +320,11 @@ def test_filterable_metadata_query(
     if not NOT_CLUSTER_ONLY:
         # Only wait for compaction if the size of the collection is
        # some minimal size
-        if should_compact and len(invariants.wrap(record_set["ids"])) > 10:
+        if should_compact and len(invariants.wrap(record_set["ids"])) > 10:  # type: ignore[arg-type]
            # Wait for the model to be updated
            wait_for_version_increase(client, collection.name, initial_version)  # type: ignore

-    total_count = len(normalized_record_set["ids"])
+    total_count = len(normalized_record_set["ids"])  # type: ignore[arg-type]
     # Pick a random vector
     random_query: Embedding
     if collection.has_embeddings:
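The filtering tests index into ids positionally, so they now fail fast when a record set's ids were never materialized. A compact sketch of that guard plus the single-value normalization the filter helpers rely on (the helper name is illustrative, not part of the test suite):

from typing import List, Optional, Union

def require_ids(ids: Optional[Union[str, List[str]]]) -> List[str]:
    # A missing id list is a programming error here, not a skippable case.
    if ids is None:
        raise ValueError("Record set IDs cannot be None")
    # Single-record sets may carry a bare string instead of a list.
    return [ids] if isinstance(ids, str) else ids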
diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py
index 92f65b27714..a4a1f4272a6 100644
--- a/chromadb/test/property/test_persist.py
+++ b/chromadb/test/property/test_persist.py
@@ -75,7 +75,7 @@ def settings(request: pytest.FixtureRequest) -> Generator[Settings, None, None]:
 @given(
     collection_strategy=collection_st,
-    embeddings_strategy=strategies.recordsets(collection_st),
+    embeddings_strategy=strategies.recordsets(collection_st, can_ids_be_empty=True),
 )
 def test_persist(
     settings: Settings,
@@ -93,7 +93,9 @@ def test_persist(
         embedding_function=collection_strategy.embedding_function,
     )

-    coll.add(**embeddings_strategy)  # type: ignore[arg-type]
+    result = coll.add(**embeddings_strategy)  # type: ignore[arg-type]
+    if embeddings_strategy["ids"] is None:
+        embeddings_strategy["ids"] = result["ids"]

     invariants.count(coll, embeddings_strategy)
     invariants.metadatas_match(coll, embeddings_strategy)
@@ -103,6 +105,7 @@ def test_persist(
         coll,
         embeddings_strategy,
         embedding_function=collection_strategy.embedding_function,
+        n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
     )

     system_1.stop()
@@ -125,6 +128,7 @@ def test_persist(
         coll,
         embeddings_strategy,
         embedding_function=collection_strategy.embedding_function,
+        n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
     )

     system_2.stop()
@@ -194,7 +198,7 @@ def get_index_last_modified_at() -> float:
 def load_and_check(
     settings: Settings,
     collection_name: str,
-    record_set: strategies.RecordSet,
+    record_set: strategies.StateMachineRecordSet,
     conn: Connection,
 ) -> None:
     try:
@@ -206,11 +210,11 @@ def load_and_check(
             name=collection_name,
             embedding_function=strategies.not_implemented_embedding_function(),  # type: ignore[arg-type]
         )
-        invariants.count(coll, record_set)
-        invariants.metadatas_match(coll, record_set)
-        invariants.documents_match(coll, record_set)
-        invariants.ids_match(coll, record_set)
-        invariants.ann_accuracy(coll, record_set)
+        invariants.count_state_record_set(coll, record_set)
+        invariants.metadatas_match_state_record_set(coll, record_set)
+        invariants.documents_match_state_record_set(coll, record_set)
+        invariants.ids_match(coll, record_set)  # type: ignore[arg-type]
+        invariants.ann_accuracy(coll, record_set, n_records=invariants.get_n_items_from_record_set_state(record_set))  # type: ignore[arg-type]

         system.stop()
     except Exception as e:
diff --git a/chromadb/test/test_multithreaded.py b/chromadb/test/test_multithreaded.py
index 745f562f4b0..a34f9acfbcb 100644
--- a/chromadb/test/test_multithreaded.py
+++ b/chromadb/test/test_multithreaded.py
@@ -55,17 +55,17 @@ def _test_multithreaded_add(
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         futures: List[Future[Any]] = []
         total_sent = -1
-        while total_sent < len(ids):
+        while total_sent < len(ids):  # type: ignore[arg-type]
             # Randomly grab up to 10% of the dataset and send it to the executor
             batch_size = random.randint(1, N // 10)
-            to_send = min(batch_size, len(ids) - total_sent)
+            to_send = min(batch_size, len(ids) - total_sent)  # type: ignore[arg-type]
             start = total_sent + 1
             end = total_sent + to_send + 1
             if embeddings is not None and len(embeddings[start:end]) == 0:
                 break
             future = executor.submit(
                 coll.add,
-                ids=ids[start:end],
+                ids=ids[start:end],  # type: ignore[index]
                 embeddings=embeddings[start:end] if embeddings is not None else None,
                 metadatas=metadatas[start:end] if metadatas is not None else None,  # type: ignore
                 documents=documents[start:end] if documents is not None else None,
@@ -93,6 +93,7 @@ def _test_multithreaded_add(
     invariants.ann_accuracy(
         coll,
         records_set,
+        n_records=invariants.get_n_items_from_record_set(records_set),
         n_results=n_results,
         query_indices=query_indices,
     )
@@ -210,6 +211,7 @@ def perform_operation(
     invariants.ann_accuracy(
         coll,
         records_set,
+        n_records=invariants.get_n_items_from_record_set(records_set),
         n_results=n_results,
         query_indices=query_indices,
     )
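For reference, the distinction driving the load_and_check change above: a RecordSet drawn from the strategies may have None fields (including ids), while the state machine's record set always holds concrete lists. A hedged sketch of the two shapes, inferred from how they are used in this diff rather than copied from chromadb:

from typing import Dict, List, Optional, TypedDict, Union

Metadata = Dict[str, Union[str, int, float]]

class RecordSetSketch(TypedDict):
    # Drawn from hypothesis strategies; any field, including ids, may be None,
    # and single-record sets may hold a bare value instead of a list.
    ids: Optional[Union[str, List[str]]]
    embeddings: Optional[List[List[float]]]
    metadatas: Optional[List[Optional[Metadata]]]
    documents: Optional[Union[str, List[str]]]

class StateMachineRecordSetSketch(TypedDict):
    # Mirrors the model state: always concrete lists, entries may be None.
    ids: List[str]
    embeddings: List[List[float]]
    metadatas: List[Optional[Metadata]]
    documents: List[Optional[str]]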