Skip to content

Commit

Permalink
Cleanup & Organization
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma authored and atroyn committed Oct 2, 2024
1 parent d2dcb57 commit b726474
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 169 deletions.
2 changes: 2 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ def _add(

validate_batch_size(record_set, {"max_batch_size": self.get_max_batch_size()})

# This differs from the request for update and upsert because we want to return the ids,
# which are generated at the server (segment)
resp_json = self._make_request(
"post",
"/collections/" + str(collection_id) + "/add",
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ async def update(

await self._client._update(
collection_id=self.id,
# TODO: We slightly abuse the RecordSet type here because on the type IDs could be None
ids=cast(IDs, record_set["ids"]),
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=record_set["metadatas"],
Expand Down Expand Up @@ -318,6 +319,7 @@ async def upsert(

await self._client._upsert(
collection_id=self.id,
# TODO: We slightly abuse the RecordSet type here because on the type IDs could be None
ids=cast(IDs, record_set["ids"]),
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=record_set["metadatas"],
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def update(

self._client._update(
collection_id=self.id,
# TODO: We slightly abuse the RecordSet type here because on the type IDs could be None
ids=cast(IDs, record_set["ids"]),
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=record_set["metadatas"],
Expand Down Expand Up @@ -316,6 +317,7 @@ def upsert(

self._client._upsert(
collection_id=self.id,
# TODO: We slightly abuse the RecordSet type here because on the type IDs could be None
ids=cast(IDs, record_set["ids"]),
embeddings=cast(Embeddings, record_set["embeddings"]),
metadatas=record_set["metadatas"],
Expand Down
2 changes: 1 addition & 1 deletion chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _unpack_record_set(
uris: Optional[OneOrMany[URI]] = None,
) -> RecordSet:
return {
"ids": cast(IDs, maybe_cast_one_to_many(ids)),
"ids": maybe_cast_one_to_many(ids),
"embeddings": maybe_cast_one_to_many_embedding(embeddings),
"metadatas": maybe_cast_one_to_many(metadatas),
"documents": maybe_cast_one_to_many(documents),
Expand Down
45 changes: 19 additions & 26 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
validate_where_document,
validate_batch_size,
validate_record_set,
get_n_items_from_record_set,
count_records,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
Expand Down Expand Up @@ -347,6 +347,7 @@ def _add(
coll = self._get_collection(collection_id)
self._manager.hint_use_collection(collection_id, t.Operation.ADD)

# TODO: We slightly abuse the record_set type here
record_set: RecordSet = {
"ids": ids,
"embeddings": embeddings,
Expand All @@ -355,9 +356,12 @@ def _add(
"metadatas": metadatas,
"images": None,
}
n = count_records(record_set)

ids = self.generate_ids_when_not_present(record_set)
record_set["ids"] = ids
# Generate ids if not provided
if record_set["ids"] is None:
n = count_records(record_set)
record_set["ids"] = [str(uuid4()) for _ in range(n)]

self._validate_record_set(
collection=coll,
Expand All @@ -368,25 +372,27 @@ def _add(
records_to_submit = list(
_records(
t.Operation.ADD,
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
uris=uris,
ids=cast(IDs, record_set["ids"]), # IDs validator checks for None
embeddings=record_set["embeddings"],
metadatas=record_set["metadatas"],
documents=record_set["documents"],
uris=record_set["uris"],
)
)
self._producer.submit_embeddings(collection_id, records_to_submit)

self._product_telemetry_client.capture(
CollectionAddEvent(
collection_uuid=str(collection_id),
add_amount=len(ids),
with_metadata=len(ids) if metadatas is not None else 0,
with_documents=len(ids) if documents is not None else 0,
with_uris=len(ids) if uris is not None else 0,
add_amount=n,
with_metadata=n if metadatas is not None else 0,
with_documents=n if documents is not None else 0,
with_uris=n if uris is not None else 0,
)
)
return AddResult(ids=ids)
return AddResult(
ids=cast(IDs, record_set["ids"])
) # IDs validator checks for None

@trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
@override
Expand Down Expand Up @@ -883,19 +889,6 @@ def get_settings(self) -> Settings:
def get_max_batch_size(self) -> int:
return self._producer.max_batch_size

@staticmethod
def generate_ids_when_not_present(
record_set: RecordSet,
) -> IDs:
ids = record_set["ids"]
if ids is not None:
return ids

n = get_n_items_from_record_set(record_set)
generated_ids: List[str] = [str(uuid4()) for _ in range(n)]

return generated_ids

# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
Expand Down
66 changes: 31 additions & 35 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ def __call__(self, uris: URIs) -> L:
...


def validate_ids(ids: IDs) -> IDs:
def validate_ids(ids: Optional[IDs]) -> IDs:
"""Validates ids to ensure it is a list of strings"""
if ids is None:
raise ValueError("Expected IDs to be a non-empty list of str, got None")
if not isinstance(ids, list):
raise ValueError(f"Expected IDs to be a list, got {type(ids).__name__} as IDs")
if len(ids) == 0:
Expand Down Expand Up @@ -610,59 +612,50 @@ def validate_batch_size(
record_set: RecordSet,
limits: Dict[str, Any],
) -> None:
batch_size = get_n_items_from_record_set(record_set)
batch_size = count_records(record_set)

if batch_size > limits["max_batch_size"]:
raise ValueError(
f"Batch size {batch_size} exceeds maximum batch size {limits['max_batch_size']}"
)


def validate_record_set_consistency(record_set: RecordSet) -> None:
def validate_record_set_count(record_set: RecordSet) -> None:
"""
Validate the consistency of the record set, ensuring all values are non-empty lists and have the same length.
"""
error_messages = []
field_record_counts = []
count = 0
consistency_error_found = False

for field, value in record_set.items():
if value is None:
continue

if not isinstance(value, list):
error_messages.append(
f"Expected field {field} to be a list, got {type(value).__name__}"
)
continue
counts = {
field: len(
cast(Union[IDs, Metadatas, Embeddings, Documents, Images, URIs], value)
)
for field, value in record_set.items()
if value is not None
}

n_items = len(value)
if n_items == 0:
error_messages.append(
f"Expected field {field} to be a non-empty list, got an empty list"
)
continue
if len(counts) == 0:
raise ValueError("Expected at least one record set field to be non-empty")

field_record_counts.append(f"{field}: ({n_items})")
if count == 0:
count = n_items
elif count != n_items:
consistency_error_found = True
if any(count == 0 for count in counts.values()):
error_messages.append(
f"Expected all fields to be non-empty lists, got empty lists in: {', '.join(field for field, count in counts.items() if count == 0)}"
)

if consistency_error_found:
if len(counts) > 1:
field_record_counts = [f"{field}: ({count})" for field, count in counts.items()]
error_messages.append(
f"Inconsistent number of records: {', '.join(field_record_counts)}"
)

if len(error_messages) > 0:
if error_messages:
raise ValueError(", ".join(error_messages))


def validate_record_set(
record_set: RecordSet,
require_data: bool,
) -> None:
validate_record_set_count(record_set)
validate_ids(record_set["ids"])
validate_embeddings(record_set["embeddings"]) if record_set[
"embeddings"
Expand All @@ -682,27 +675,30 @@ def validate_record_set(
if not record_set_contains_one_of(record_set, include=required_fields):
raise ValueError(f"You must provide one of {', '.join(required_fields)}")

valid_ids = record_set["ids"]
# The ID validator checks for None
valid_ids = cast(IDs, record_set["ids"])
for key in ["embeddings", "metadatas", "documents", "images", "uris"]:
if record_set[key] is not None and len(record_set[key]) != len(valid_ids): # type: ignore[literal-required]
raise ValueError(
f"Number of {key} {len(record_set[key])} must match number of ids {len(valid_ids)}" # type: ignore[literal-required]
)


def get_n_items_from_record_set(
def count_records(
record_set: RecordSet,
) -> int:
"""
Get the number of items in the record set.
"""

validate_record_set_consistency(record_set)
validate_record_set_count(record_set)
for value in record_set.values():
if isinstance(value, list) and len(value) > 0:
return len(value)
if value is not None:
return len(
cast(Union[IDs, Embeddings, Metadatas, Documents, Images, URIs], value)
)

return "", 0
raise ValueError("Expected at least one record set field to be non-empty")


def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:
Expand Down
2 changes: 1 addition & 1 deletion chromadb/test/api/test_api_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def test_update_with_none_ids(client: ClientAPI) -> None:
collection = client.create_collection("test")
with pytest.raises(ValueError) as e:
collection.update(ids=None, embeddings=[[0.1, 0.2, 0.3]]) # type: ignore[arg-type]
assert "You must provide ids when updating." in str(e)
assert e.match("You must provide ids when updating.")
2 changes: 1 addition & 1 deletion chromadb/test/api/test_api_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def test_upsert_with_none_ids(client: ClientAPI) -> None:
collection = client.create_collection("test")
with pytest.raises(ValueError) as e:
collection.upsert(ids=None, embeddings=[[0.1, 0.2, 0.3]]) # type: ignore[arg-type]
assert "You must provide ids when upserting." in str(e)
assert e.match("You must provide ids when upserting.")
Loading

0 comments on commit b726474

Please sign in to comment.