Skip to content

Commit

Permalink
resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Aug 22, 2024
1 parent f66ed07 commit 1872593
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 36 deletions.
13 changes: 9 additions & 4 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chromadb.api.types import (
URI,
AddResult,
CollectionMetadata,
Embedding,
Include,
Expand All @@ -33,7 +34,7 @@
class AsyncCollection(CollectionCommon["AsyncServerAPI"]):
async def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -44,7 +45,7 @@ async def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -75,14 +76,18 @@ async def add(
)

await self._client._add(
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
self.id,
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
embedding_set["uris"],
)

return {
"ids": embedding_set["ids"],
}

async def count(self) -> int:
"""The total number of embeddings added to the database
Expand Down Expand Up @@ -303,7 +308,7 @@ async def upsert(

await self._client._upsert(
collection_id=self.id,
ids=embedding_set["ids"],
ids=cast(IDs, embedding_set["ids"]),
embeddings=cast(Embeddings, embedding_set["embeddings"]),
metadatas=embedding_set["metadatas"],
documents=embedding_set["documents"],
Expand Down
15 changes: 10 additions & 5 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Include,
Metadata,
Document,
AddResult,
Image,
Where,
IDs,
Expand Down Expand Up @@ -40,7 +41,7 @@ def count(self) -> int:

def add(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]] = None,
embeddings: Optional[ # type: ignore[type-arg]
Union[
OneOrMany[Embedding],
Expand All @@ -51,7 +52,7 @@ def add(
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> None:
) -> AddResult:
"""Add embeddings to the data store.
Args:
ids: The ids of the embeddings you wish to add
Expand Down Expand Up @@ -82,14 +83,18 @@ def add(
)

self._client._add(
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
self.id,
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
embedding_set["uris"],
)

return {
"ids": embedding_set["ids"],
}

def get(
self,
ids: Optional[OneOrMany[ID]] = None,
Expand Down Expand Up @@ -263,7 +268,7 @@ def update(

self._client._update(
self.id,
embedding_set["ids"],
cast(IDs, embedding_set["ids"]),
cast(Embeddings, embedding_set["embeddings"]),
embedding_set["metadatas"],
embedding_set["documents"],
Expand Down Expand Up @@ -301,7 +306,7 @@ def upsert(

self._client._upsert(
collection_id=self.id,
ids=embedding_set["ids"],
ids=cast(IDs, embedding_set["ids"]),
embeddings=cast(Embeddings, embedding_set["embeddings"]),
metadatas=embedding_set["metadatas"],
documents=embedding_set["documents"],
Expand Down
61 changes: 48 additions & 13 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cast,
)
import numpy as np
from uuid import UUID
from uuid import UUID, uuid4

import chromadb.utils.embedding_functions as ef
from chromadb.api.types import (
Expand Down Expand Up @@ -151,7 +151,7 @@ def get_model(self) -> CollectionModel:

def _unpack_embedding_set(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand Down Expand Up @@ -181,7 +181,7 @@ def _unpack_embedding_set(

def _validate_embedding_set(
self,
ids: IDs,
ids: Optional[IDs],
embeddings: Optional[Embeddings],
metadatas: Optional[Metadatas],
documents: Optional[Documents],
Expand Down Expand Up @@ -215,7 +215,7 @@ def _validate_embedding_set(
)

# Only one of documents or images can be provided
if valid_documents is not None and valid_images is not None:
if documents is not None and images is not None:
raise ValueError("You can only provide documents or images, not both.")

# Check that, if they're provided, the lengths of the arrays match the length of ids
Expand All @@ -227,17 +227,17 @@ def _validate_embedding_set(
raise ValueError(
f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}"
)
if valid_documents is not None and len(valid_documents) != len(valid_ids):
if documents is not None and len(documents) != len(valid_ids):
raise ValueError(
f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}"
f"Number of documents {len(documents)} must match number of ids {len(valid_ids)}"
)
if valid_images is not None and len(valid_images) != len(valid_ids):
if images is not None and len(images) != len(valid_ids):
raise ValueError(
f"Number of images {len(valid_images)} must match number of ids {len(valid_ids)}"
f"Number of images {len(images)} must match number of ids {len(valid_ids)}"
)
if valid_uris is not None and len(valid_uris) != len(valid_ids):
if uris is not None and len(uris) != len(valid_ids):
raise ValueError(
f"Number of uris {len(valid_uris)} must match number of ids {len(valid_ids)}"
f"Number of uris {len(uris)} must match number of ids {len(valid_ids)}"
)

def _compute_embeddings(
Expand Down Expand Up @@ -426,9 +426,36 @@ def _update_model_after_modify_success(
if metadata:
self._model["metadata"] = metadata

@staticmethod
def _generate_ids_when_not_present(
ids: Optional[IDs],
documents: Optional[Documents],
uris: Optional[URIs],
images: Optional[Images],
embeddings: Optional[Embeddings],
) -> IDs:
if ids is not None and len(ids) > 0:
return ids

n = 0
if documents is not None:
n = len(documents)
elif uris is not None:
n = len(uris)
elif images is not None:
n = len(images)
elif embeddings is not None:
n = len(embeddings)

generated_ids = []
for _ in range(n):
generated_ids.append(str(uuid4()))

return generated_ids

def _process_add_request(
self,
ids: OneOrMany[ID],
ids: Optional[OneOrMany[ID]],
embeddings: Optional[
Union[
OneOrMany[Embedding],
Expand All @@ -455,8 +482,16 @@ def _process_add_request(
else None
)

self._validate_embedding_set(
generated_ids = self._generate_ids_when_not_present(
unpacked_embedding_set["ids"],
unpacked_embedding_set["documents"],
unpacked_embedding_set["uris"],
unpacked_embedding_set["images"],
normalized_embeddings,
)

self._validate_embedding_set(
generated_ids,
normalized_embeddings,
unpacked_embedding_set["metadatas"],
unpacked_embedding_set["documents"],
Expand All @@ -473,7 +508,7 @@ def _process_add_request(
)

return {
"ids": unpacked_embedding_set["ids"],
"ids": generated_ids,
"embeddings": prepared_embeddings,
"metadatas": unpacked_embedding_set["metadatas"],
"documents": unpacked_embedding_set["documents"],
Expand Down
8 changes: 8 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ class RecordSet(TypedDict):
L = TypeVar("L", covariant=True, bound=Loadable)


class AddResult(TypedDict):
    """Result returned by a collection's ``add`` call."""

    # Ids of the added records: the caller-supplied ids, or the ids
    # generated for the request when the caller supplied none.
    ids: IDs


class GetResult(TypedDict):
ids: List[ID]
embeddings: Optional[List[Embedding]]
Expand Down Expand Up @@ -290,6 +294,10 @@ def validate_ids(ids: IDs) -> IDs:
for id_ in ids:
if not isinstance(id_, str):
raise ValueError(f"Expected ID to be a str, got {id_}")

if len(id_) == 0:
raise ValueError("Expected ID to be a non-empty str, got an empty string")

if id_ in seen:
dups.add(id_)
else:
Expand Down
25 changes: 16 additions & 9 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,26 @@ def recordsets(
num_unique_metadata: Optional[int] = None,
min_metadata_size: int = 0,
max_metadata_size: Optional[int] = None,
# ids can only be optional for add operations
for_add: bool = False,
) -> RecordSet:
collection = draw(collection_strategy)

ids = list(
draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True))
)

# This probabilistic event is used to mimic user behavior when they don't provide ids
if for_add and np.random.rand() < 0.5:
ids = []

n = len(ids)
if len(ids) == 0:
n = int(draw(st.integers(min_value=min_size, max_value=max_size)))

embeddings: Optional[Embeddings] = None
if collection.has_embeddings:
embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype)
num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids)
embeddings = create_embeddings(collection.dimension, n, collection.dtype)
num_metadata = num_unique_metadata if num_unique_metadata is not None else n
generated_metadatas = draw(
st.lists(
metadata(
Expand All @@ -479,20 +488,18 @@ def recordsets(
)
)
metadatas = []
for i in range(len(ids)):
for i in range(n):
metadatas.append(generated_metadatas[i % len(generated_metadatas)])

documents: Optional[Documents] = None
if collection.has_documents:
documents = draw(
st.lists(document(collection), min_size=len(ids), max_size=len(ids))
)
documents = draw(st.lists(document(collection), min_size=n, max_size=n))

# in the case where we have a single record, sometimes exercise
# the code that handles individual values rather than lists.
# In this case, any field may be a list or a single value.
if len(ids) == 1:
single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids
if n == 1:
single_id: Union[str, List[str]] = ids[0] if len(ids) == 1 else ids
single_embedding = (
embeddings[0]
if embeddings is not None and draw(st.booleans())
Expand Down
23 changes: 18 additions & 5 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from hypothesis import given, settings, HealthCheck
from typing import Dict, Set, cast, Union, DefaultDict, Any, List
from dataclasses import dataclass
from chromadb.api.types import ID, Embeddings, Include, IDs, validate_embeddings
from chromadb.api.types import (
ID,
Embeddings,
Include,
IDs,
validate_embeddings,
)
from chromadb.config import System
import chromadb.errors as errors
from chromadb.api import ClientAPI
Expand Down Expand Up @@ -104,7 +110,7 @@ def teardown(self) -> None:

@rule(
target=embedding_ids,
record_set=strategies.recordsets(collection_st),
record_set=strategies.recordsets(collection_st, for_add=True),
)
def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]:
trace("add_embeddings")
Expand All @@ -114,7 +120,10 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
record_set
)

if len(normalized_record_set["ids"]) > 0:
if (
normalized_record_set["metadatas"] is not None
and len(normalized_record_set["metadatas"]) > 0
):
trace("add_more_embeddings")

intersection = set(normalized_record_set["ids"]).intersection(
Expand All @@ -141,9 +150,13 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
return multiple(*filtered_record_set["ids"])

else:
self.collection.add(**normalized_record_set) # type: ignore[arg-type]
result = self.collection.add(**normalized_record_set) # type: ignore[arg-type]
ids = result["ids"]
normalized_record_set["ids"] = ids

self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set))
return multiple(*normalized_record_set["ids"])

return multiple(*ids)

@rule(ids=st.lists(consumes(embedding_ids), min_size=1))
def delete_by_ids(self, ids: IDs) -> None:
Expand Down
5 changes: 5 additions & 0 deletions chromadb/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,11 @@ def test_invalid_id(client):
collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}])
assert "ID" in str(e.value)

# Upsert with an empty id
with pytest.raises(ValueError) as e:
collection.upsert(embeddings=[0, 0, 0], ids=[""])
assert "non-empty" in str(e.value)

# Get with non-list id
with pytest.raises(ValueError) as e:
collection.get(ids=1)
Expand Down

0 comments on commit 1872593

Please sign in to comment.