diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index a814cfcc6f9..f71720a1152 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -25,8 +25,8 @@ GetResult, QueryResult, CollectionMetadata, - validate_batch, convert_np_embeddings_to_list, + validate_batch, ) from chromadb.auth import ( ClientAuthProvider, diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 394c18b96cf..d55b80d69a0 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -41,13 +41,13 @@ validate_ids, validate_include, validate_metadata, - validate_metadatas, validate_embeddings, validate_embedding_function, validate_n_results, validate_where, validate_where_document, record_set_contains_one_of, + validate_record_set, ) # TODO: We should rename the types in chromadb.types to be Models where @@ -169,37 +169,6 @@ def _unpack_record_set( "uris": maybe_cast_one_to_many(uris), } - @staticmethod - def _validate_record_set( - record_set: RecordSet, - require_data: bool, - ) -> None: - validate_ids(record_set["ids"]) - validate_embeddings(record_set["embeddings"]) if record_set[ - "embeddings" - ] is not None else None - validate_metadatas(record_set["metadatas"]) if record_set[ - "metadatas" - ] is not None else None - - # Only one of documents or images can be provided - if record_set["documents"] is not None and record_set["images"] is not None: - raise ValueError("You can only provide documents or images, not both.") - - required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item] - if not require_data: - required_fields += ["metadatas"] # type: ignore[list-item] - - if not record_set_contains_one_of(record_set, include=required_fields): - raise ValueError(f"You must provide one of {', '.join(required_fields)}") - - valid_ids = record_set["ids"] - for key in ["embeddings", "metadatas", "documents", "images", "uris"]: - if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] - raise ValueError( - f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required] - ) - def _compute_embeddings( self, documents: Optional[Documents], @@ -406,7 +375,7 @@ def _process_add_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=True, ) @@ -443,7 +412,7 @@ def _process_upsert_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=True, ) @@ -481,7 +450,7 @@ def _process_update_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=False, ) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index ca4d7a01644..70243a80517 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -19,6 +19,7 @@ from chromadb.errors import ( InvalidDimensionException, InvalidCollectionException, + InvalidInputException, VersionMismatchError, ) from chromadb.api.types import ( @@ -34,6 +35,7 @@ Where, WhereDocument, Include, + RecordSet, GetResult, QueryResult, validate_metadata, @@ -41,6 +43,7 @@ validate_where, validate_where_document, validate_batch, + validate_record_set, ) from chromadb.telemetry.product.events import ( CollectionAddEvent, @@ -341,10 +344,20 @@ def _add( self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.ADD) - validate_batch( - (ids, embeddings, metadatas, documents, uris), - {"max_batch_size": self.get_max_batch_size()}, + + self._validate_record_set( + collection=coll, + record_set={ + "ids": ids, + "embeddings": embeddings, + "documents": documents, + "uris": uris, + "metadatas": metadatas, + "images": None, + }, + require_data=True, ) + records_to_submit = list( _records( t.Operation.ADD, @@ -355,7 +368,6 @@ def _add( uris=uris, ) ) - self._validate_embedding_record_set(coll, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) self._product_telemetry_client.capture( @@ -383,10 +395,20 @@ def _update( self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPDATE) - validate_batch( - (ids, embeddings, metadatas, documents, uris), - {"max_batch_size": self.get_max_batch_size()}, + + self._validate_record_set( + collection=coll, + record_set={ + "ids": ids, + "embeddings": embeddings, + "documents": documents, + "uris": uris, + "metadatas": metadatas, + "images": None, + }, + require_data=False, ) + records_to_submit = list( _records( t.Operation.UPDATE, @@ -397,7 +419,6 @@ def _update( uris=uris, ) ) - self._validate_embedding_record_set(coll, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) self._product_telemetry_client.capture( @@ -427,10 +448,20 @@ def _upsert( self._quota.static_check(metadatas, documents, embeddings, str(collection_id)) coll = self._get_collection(collection_id) self._manager.hint_use_collection(collection_id, t.Operation.UPSERT) - validate_batch( - (ids, embeddings, metadatas, documents, uris), - {"max_batch_size": self.get_max_batch_size()}, + + self._validate_record_set( + collection=coll, + record_set={ + "ids": ids, + "embeddings": embeddings, + "documents": documents, + "uris": uris, + "metadatas": metadatas, + "images": None, + }, + require_data=True, ) + records_to_submit = list( _records( t.Operation.UPSERT, @@ -441,7 +472,6 @@ def _upsert( uris=uris, ) ) - self._validate_embedding_record_set(coll, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) return True @@ -630,7 +660,6 @@ def _delete( records_to_submit = list( _records(operation=t.Operation.DELETE, ids=ids_to_delete) ) - self._validate_embedding_record_set(coll, records_to_submit) self._producer.submit_embeddings(collection_id, records_to_submit) self._product_telemetry_client.capture( @@ -851,19 +880,41 @@ def get_max_batch_size(self) -> int: # system, since the cache is only local. # TODO: promote collection -> topic to a base class method so that it can be # used for channel assignment in the distributed version of the system. - @trace_method( - "SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL - ) - def _validate_embedding_record_set( - self, collection: t.Collection, records: List[t.OperationRecord] + @trace_method("SegmentAPI._validate_record_set", OpenTelemetryGranularity.ALL) + def _validate_record_set( + self, + collection: t.Collection, + record_set: RecordSet, + require_data: bool, ) -> None: - """Validate the dimension of an embedding record before submitting it to the system.""" add_attributes_to_current_span({"collection_id": str(collection["id"])}) - for record in records: - if record["embedding"] is not None: - self._validate_dimension( - collection, len(record["embedding"]), update=True - ) + + try: + validate_record_set(record_set, require_data=require_data) + validate_batch( + ( + record_set["ids"], + record_set["embeddings"], + record_set["metadatas"], + record_set["documents"], + record_set["uris"], + ), + {"max_batch_size": self.get_max_batch_size()}, + ) + + if require_data and record_set["embeddings"] is None: + raise ValueError("You must provide embeddings") + + if record_set["embeddings"] is not None: + """Validate the dimension of an embedding record before submitting it to the system.""" + for embedding in record_set["embeddings"]: + if embedding: + self._validate_dimension( + collection, len(embedding), update=True + ) + + except ValueError as e: + raise InvalidInputException(f"{e}") # This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings. def _validate_dimension( diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 87a0c60a11f..58104434b66 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -57,7 +57,7 @@ def maybe_cast_one_to_many(target: Optional[(OneOrMany[T])]) -> Optional[List[T] def maybe_cast_one_to_many_embedding( - target: Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]] + target: Union[Optional[OneOrMany[Embedding]], Optional[OneOrMany[PyEmbedding]]] ) -> Optional[Embeddings]: if target is None: return None @@ -616,6 +616,37 @@ def validate_batch( ) +def validate_record_set( + record_set: RecordSet, + require_data: bool, +) -> None: + validate_ids(record_set["ids"]) + validate_embeddings(record_set["embeddings"]) if record_set[ + "embeddings" + ] is not None else None + validate_metadatas(record_set["metadatas"]) if record_set[ + "metadatas" + ] is not None else None + + # Only one of documents or images can be provided + if record_set["documents"] is not None and record_set["images"] is not None: + raise ValueError("You can only provide documents or images, not both.") + + required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item] + if not require_data: + required_fields += ["metadatas"] # type: ignore[list-item] + + if not record_set_contains_one_of(record_set, include=required_fields): + raise ValueError(f"You must provide one of {', '.join(required_fields)}") + + valid_ids = record_set["ids"] + for key in ["embeddings", "metadatas", "documents", "images", "uris"]: + if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] + raise ValueError( + f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required] + ) + + def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings: return [embedding.tolist() for embedding in embeddings] diff --git a/chromadb/errors.py b/chromadb/errors.py index 18e95fb7c44..99bc1bc8970 100644 --- a/chromadb/errors.py +++ b/chromadb/errors.py @@ -20,6 +20,13 @@ def name(cls) -> str: pass +class InvalidInputException(ChromaError): + @classmethod + @overrides + def name(cls) -> str: + return "InvalidInput" + + class InvalidDimensionException(ChromaError): @classmethod @overrides @@ -137,6 +144,7 @@ def name(cls) -> str: error_types: Dict[str, Type[ChromaError]] = { + "InvalidInput": InvalidInputException, "InvalidDimension": InvalidDimensionException, "InvalidCollection": InvalidCollectionException, "IDAlreadyExists": IDAlreadyExistsError, diff --git a/chromadb/test/api/test_validations.py b/chromadb/test/api/test_validations.py index 3264470c68f..c35499864a6 100644 --- a/chromadb/test/api/test_validations.py +++ b/chromadb/test/api/test_validations.py @@ -4,6 +4,8 @@ RecordSet, record_set_contains_one_of, maybe_cast_one_to_many_embedding, + validate_embeddings, + Embeddings, ) @@ -53,13 +55,13 @@ def test_maybe_cast_one_to_many_embedding() -> None: assert maybe_cast_one_to_many_embedding(None) is None # Test with a single embedding as a list - single_embedding = [1.0, 2.0, 3.0] + single_embedding = np.array([1.0, 2.0, 3.0]) result = maybe_cast_one_to_many_embedding(single_embedding) assert result == [single_embedding] # Test with multiple embeddings as a list of lists - multiple_embeddings = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - result = maybe_cast_one_to_many_embedding(multiple_embeddings) # type: ignore[arg-type] + multiple_embeddings = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])] + result = maybe_cast_one_to_many_embedding(multiple_embeddings) assert result == multiple_embeddings # Test with a numpy array (single embedding) @@ -96,3 +98,31 @@ def test_maybe_cast_one_to_many_embedding() -> None: match="Expected embeddings to be a list or a numpy array, got str", ): maybe_cast_one_to_many_embedding("") # type: ignore[arg-type] + + +def test_embeddings_validation() -> None: + invalid_embeddings = [[0, 0, True], [1.2, 2.24, 3.2]] + + with pytest.raises(ValueError) as e: + validate_embeddings(invalid_embeddings) # type: ignore[arg-type] + + assert "Expected each value in the embedding to be a int or float" in str(e) + + invalid_embeddings = [[0, 0, "invalid"], [1.2, 2.24, 3.2]] + + with pytest.raises(ValueError) as e: + validate_embeddings(invalid_embeddings) # type: ignore[arg-type] + + assert "Expected each value in the embedding to be a int or float" in str(e) + + with pytest.raises(ValueError) as e: + validate_embeddings("invalid") # type: ignore[arg-type] + + assert "Expected embeddings to be a list, got str" in str(e) + + +def test_0dim_embedding_validation() -> None: + embds: Embeddings = [[]] # type: ignore[list-item] + with pytest.raises(ValueError) as e: + validate_embeddings(embds) + assert "Expected each embedding in the embeddings to be a non-empty list" in str(e) diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 50d9fd9a820..21bf6c0f4b9 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -10,9 +10,9 @@ from dataclasses import dataclass from chromadb.api.types import ( ID, - Embeddings, Include, IDs, + Embeddings, validate_embeddings, maybe_cast_one_to_many_embedding, ) @@ -803,7 +803,6 @@ def test_autocasting_validate_embeddings_for_compatible_types( supported_types: List[Any], ) -> None: embds = strategies.create_embeddings(10, 10, supported_types) - validated_embeddings = validate_embeddings(maybe_cast_one_to_many_embedding(embds)) # type: ignore[arg-type] assert all( [ diff --git a/clients/js/src/ChromaFetch.ts b/clients/js/src/ChromaFetch.ts index ebd98cc15a9..da9a9045ec6 100644 --- a/clients/js/src/ChromaFetch.ts +++ b/clients/js/src/ChromaFetch.ts @@ -59,7 +59,7 @@ export const chromaFetch: FetchAPI = async ( switch (resp.status) { case 400: throw new ChromaClientError( - `Bad request to ${input} with status: ${resp.statusText}`, + `Bad request to ${input} with message: ${respBody?.message}`, ); case 401: throw new ChromaUnauthorizedError(`Unauthorized`); diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index 76d0e9e203c..8f836e9d404 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -133,11 +133,9 @@ describe("add collections", () => { const ids = IDS.concat(["test1"]); const embeddings = EMBEDDINGS.concat([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]); const metadatas = METADATAS.concat([{ test: "test1", float_value: 0.1 }]); - try { + expect(async () => { await collection.add({ ids, embeddings, metadatas }); - } catch (e: any) { - expect(e.message).toMatch("duplicates"); - } + }).rejects.toThrow("found duplicates"); }); test("should error on empty embedding", async () => { @@ -145,11 +143,9 @@ describe("add collections", () => { const ids = ["id1"]; const embeddings = [[]]; const metadatas = [{ test: "test1", float_value: 0.1 }]; - try { + expect(async () => { await collection.add({ ids, embeddings, metadatas }); - } catch (e: any) { - expect(e.message).toMatch("got empty embedding at pos"); - } + }).rejects.toThrow("got empty embedding at pos"); }); if (!process.env.OLLAMA_SERVER_URL) {