diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index e1c310779c98..d55b80d69a08 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -41,13 +41,13 @@ validate_ids, validate_include, validate_metadata, - validate_metadatas, validate_embeddings, validate_embedding_function, validate_n_results, validate_where, validate_where_document, record_set_contains_one_of, + validate_record_set, ) # TODO: We should rename the types in chromadb.types to be Models where @@ -169,29 +169,6 @@ def _unpack_record_set( "uris": maybe_cast_one_to_many(uris), } - @staticmethod - def _validate_record_set( - record_set: RecordSet, - require_data: bool, - ) -> None: - # Only one of documents or images can be provided - if record_set["documents"] is not None and record_set["images"] is not None: - raise ValueError("You can only provide documents or images, not both.") - - required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item] - if not require_data: - required_fields += ["metadatas"] # type: ignore[list-item] - - if not record_set_contains_one_of(record_set, include=required_fields): - raise ValueError(f"You must provide one of {', '.join(required_fields)}") - - valid_ids = record_set["ids"] - for key in ["embeddings", "metadatas", "documents", "images", "uris"]: - if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] - raise ValueError( - f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required] - ) - def _compute_embeddings( self, documents: Optional[Documents], @@ -398,7 +375,7 @@ def _process_add_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=True, ) @@ -435,7 +412,7 @@ def _process_upsert_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=True, ) @@ -473,7 +450,7 @@ def _process_update_request( uris=uris, ) - self._validate_record_set( + validate_record_set( record_set, require_data=False, ) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 1015fc6b5b26..d6b2b3bf0033 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -638,7 +638,6 @@ def _delete( """ ) - coll = self._get_collection(collection_id) request_version_context = t.RequestVersionContext( collection_version=coll.version, log_position=coll.log_position, @@ -660,19 +659,6 @@ def _delete( if len(ids_to_delete) == 0: return [] - self._validate_record_set( - collection=coll, - record_set={ - "ids": ids_to_delete, - "embeddings": None, - "documents": None, - "uris": None, - "metadatas": None, - "images": None, - }, - require_data=False, - ) - records_to_submit = list( _records(operation=t.Operation.DELETE, ids=ids_to_delete) ) @@ -906,7 +892,7 @@ def _validate_record_set( add_attributes_to_current_span({"collection_id": str(collection["id"])}) try: - validate_record_set(record_set) + validate_record_set(record_set, require_data=require_data) validate_batch( ( record_set["ids"], diff --git a/chromadb/api/types.py b/chromadb/api/types.py index cda766274309..23182f8c6e0a 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -616,26 +616,35 @@ def validate_batch( ) -def validate_record_set(record_set: RecordSet) -> None: - embeddings = record_set["embeddings"] - ids = record_set["ids"] - metadatas = record_set["metadatas"] - - validate_ids(ids) - validate_embeddings(embeddings) if embeddings is not None else None - validate_metadatas(metadatas) if metadatas is not None else None - - for field, value in record_set.items(): - if field == "ids" or value is None: - continue - - if isinstance(value, list): - n = len(value) - n_ids = len(record_set["ids"]) - if n != n_ids: - raise ValueError( - f"Number of {field} ({n}) does not match number of ids ({n_ids})" - ) +def validate_record_set( + record_set: RecordSet, + require_data: bool, +) -> None: + validate_ids(record_set["ids"]) + validate_embeddings(record_set["embeddings"]) if record_set[ + "embeddings" + ] is not None else None + validate_metadatas(record_set["metadatas"]) if record_set[ + "metadatas" + ] is not None else None + + # Only one of documents or images can be provided + if record_set["documents"] is not None and record_set["images"] is not None: + raise ValueError("You can only provide documents or images, not both.") + + required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item] + if not require_data: + required_fields += ["metadatas"] # type: ignore[list-item] + + if not record_set_contains_one_of(record_set, include=required_fields): + raise ValueError(f"You must provide one of {', '.join(required_fields)}") + + valid_ids = record_set["ids"] + for key in ["embeddings", "metadatas", "documents", "images", "uris"]: + if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required] + raise ValueError( + f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required] + ) def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings: