Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma authored and atroyn committed Oct 2, 2024
1 parent c7c8d38 commit cea408e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 62 deletions.
31 changes: 4 additions & 27 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
validate_ids,
validate_include,
validate_metadata,
validate_metadatas,
validate_embeddings,
validate_embedding_function,
validate_n_results,
validate_where,
validate_where_document,
record_set_contains_one_of,
validate_record_set,
)

# TODO: We should rename the types in chromadb.types to be Models where
Expand Down Expand Up @@ -169,29 +169,6 @@ def _unpack_record_set(
"uris": maybe_cast_one_to_many(uris),
}

@staticmethod
def _validate_record_set(
record_set: RecordSet,
require_data: bool,
) -> None:
# Only one of documents or images can be provided
if record_set["documents"] is not None and record_set["images"] is not None:
raise ValueError("You can only provide documents or images, not both.")

required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item]
if not require_data:
required_fields += ["metadatas"] # type: ignore[list-item]

if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {', '.join(required_fields)}")

valid_ids = record_set["ids"]
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)

def _compute_embeddings(
self,
documents: Optional[Documents],
Expand Down Expand Up @@ -398,7 +375,7 @@ def _process_add_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=True,
)
Expand Down Expand Up @@ -435,7 +412,7 @@ def _process_upsert_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=True,
)
Expand Down Expand Up @@ -473,7 +450,7 @@ def _process_update_request(
uris=uris,
)

self._validate_record_set(
validate_record_set(
record_set,
require_data=False,
)
Expand Down
16 changes: 1 addition & 15 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,6 @@ def _delete(
"""
)

coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
Expand All @@ -660,19 +659,6 @@ def _delete(
if len(ids_to_delete) == 0:
return []

self._validate_record_set(
collection=coll,
record_set={
"ids": ids_to_delete,
"embeddings": None,
"documents": None,
"uris": None,
"metadatas": None,
"images": None,
},
require_data=False,
)

records_to_submit = list(
_records(operation=t.Operation.DELETE, ids=ids_to_delete)
)
Expand Down Expand Up @@ -906,7 +892,7 @@ def _validate_record_set(
add_attributes_to_current_span({"collection_id": str(collection["id"])})

try:
validate_record_set(record_set)
validate_record_set(record_set, require_data=require_data)
validate_batch(
(
record_set["ids"],
Expand Down
49 changes: 29 additions & 20 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,26 +616,35 @@ def validate_batch(
)


def validate_record_set(record_set: RecordSet) -> None:
embeddings = record_set["embeddings"]
ids = record_set["ids"]
metadatas = record_set["metadatas"]

validate_ids(ids)
validate_embeddings(embeddings) if embeddings is not None else None
validate_metadatas(metadatas) if metadatas is not None else None

for field, value in record_set.items():
if field == "ids" or value is None:
continue

if isinstance(value, list):
n = len(value)
n_ids = len(record_set["ids"])
if n != n_ids:
raise ValueError(
f"Number of {field} ({n}) does not match number of ids ({n_ids})"
)
def validate_record_set(
record_set: RecordSet,
require_data: bool,
) -> None:
validate_ids(record_set["ids"])
validate_embeddings(record_set["embeddings"]) if record_set[
"embeddings"
] is not None else None
validate_metadatas(record_set["metadatas"]) if record_set[
"metadatas"
] is not None else None

# Only one of documents or images can be provided
if record_set["documents"] is not None and record_set["images"] is not None:
raise ValueError("You can only provide documents or images, not both.")

required_fields: Include = ["embeddings", "documents", "images", "uris"] # type: ignore[list-item]
if not require_data:
required_fields += ["metadatas"] # type: ignore[list-item]

if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {', '.join(required_fields)}")

valid_ids = record_set["ids"]
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)


def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:
Expand Down

0 comments on commit cea408e

Please sign in to comment.