Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Generate IDs when not given in add #2699

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
64b775f
lint
spikechroma Aug 22, 2024
0bfe74c
fix types
spikechroma Aug 26, 2024
73d6120
update interface
spikechroma Aug 27, 2024
6307a9e
fix tests
spikechroma Aug 27, 2024
c571525
update validation logic
spikechroma Aug 28, 2024
06220f8
fix lint
spikechroma Aug 28, 2024
65d4ba3
update validate batch type compatibility
spikechroma Aug 28, 2024
56c1b0c
update fastapi types
spikechroma Aug 28, 2024
b55d5d3
remove default to empty array
spikechroma Aug 28, 2024
229606d
fix persist test
spikechroma Aug 28, 2024
eff4b59
make id optional and fix record set generation strategy to allow for …
spikechroma Aug 28, 2024
01a9fbe
make changes for optional ids
spikechroma Aug 28, 2024
66e9e6b
make changes for optional ids
spikechroma Aug 28, 2024
292b79c
make changes for optional ids
spikechroma Aug 28, 2024
d69c8e6
make changes for optional ids
spikechroma Aug 28, 2024
c104871
fix type error
spikechroma Aug 28, 2024
ac0003b
fix type error
spikechroma Aug 28, 2024
4423c8d
fix tests
spikechroma Aug 28, 2024
1acb1bf
reduce entropy in property testing
spikechroma Aug 28, 2024
d5a9cea
update ts client
spikechroma Aug 28, 2024
12953d8
lint
spikechroma Aug 28, 2024
261a2bc
create a func for getting record_set len
spikechroma Aug 28, 2024
7fbeadb
fix tests
spikechroma Aug 29, 2024
d0292f2
add comment
spikechroma Aug 29, 2024
a60ae65
fix logic error in add_embeddings
spikechroma Aug 29, 2024
5ad3506
lint
spikechroma Aug 29, 2024
04274a7
update property test
spikechroma Aug 30, 2024
78cc6a1
fix ndarray error
spikechroma Aug 30, 2024
b5ee7bd
fix tests
spikechroma Aug 30, 2024
7b9ba12
update ts client and fix cross version persist test
spikechroma Aug 30, 2024
3a7e3ee
minor updates
spikechroma Aug 30, 2024
8a6b13d
revert changes
spikechroma Sep 2, 2024
4093400
create a new func for ensuring record set consistency
spikechroma Sep 3, 2024
2fef1ec
fix broken tests
spikechroma Sep 3, 2024
d127cef
update property tests to handle validation error
spikechroma Sep 9, 2024
658179f
fix test fail
spikechroma Sep 9, 2024
0f76054
fix broken tests
spikechroma Sep 10, 2024
9fcab83
revert changes
spikechroma Sep 10, 2024
3d8d8e9
fix conflicts
spikechroma Sep 10, 2024
3954c87
update doc strings, error messages and ignore tags
spikechroma Sep 10, 2024
ac9051c
fix tests
spikechroma Sep 10, 2024
b5baff1
fix tags and update ids handling logic in cross version test
spikechroma Sep 10, 2024
b6b9fe8
update count records function
spikechroma Sep 10, 2024
6013e3f
refactor
spikechroma Sep 10, 2024
2d96f03
add additional validations
spikechroma Sep 11, 2024
a4d9f03
update logic for getting n items from record set state
spikechroma Sep 11, 2024
da08646
modify ann accuracy and count
spikechroma Sep 11, 2024
2927c01
fix error
spikechroma Sep 11, 2024
e9837f1
fix tag
spikechroma Sep 11, 2024
9060462
fix tag
spikechroma Sep 11, 2024
5d66811
fix tag
spikechroma Sep 11, 2024
f04bc38
turn on can ids be empty for test add medium
spikechroma Sep 11, 2024
3215e7d
update persist test to take in state record set
spikechroma Sep 11, 2024
33ecb52
add id validations
spikechroma Sep 11, 2024
90932a5
add tests for upsert and update with none ids
spikechroma Sep 11, 2024
0d29496
change func header
spikechroma Sep 11, 2024
d2dcb57
change func header
spikechroma Sep 11, 2024
b726474
Cleanup & Organization
spikechroma Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
URIs,
Where,
QueryResult,
AddResult,
GetResult,
WhereDocument,
)
Expand Down Expand Up @@ -115,13 +116,13 @@ def delete_collection(
@abstractmethod
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.

Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Where,
QueryResult,
GetResult,
AddResult,
WhereDocument,
)
from chromadb.config import Component, Settings
Expand Down Expand Up @@ -106,13 +107,13 @@ async def delete_collection(
@abstractmethod
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
"""[Internal] Add embeddings to a collection specified by UUID.
If (some) ids already exist, only the new embeddings will be added.

Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EmbeddingFunction,
Embeddings,
GetResult,
AddResult,
IDs,
Include,
Loadable,
Expand Down Expand Up @@ -256,13 +257,13 @@ async def delete_collection(
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
spikechroma marked this conversation as resolved.
Show resolved Hide resolved
return await self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
118 changes: 73 additions & 45 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID
import urllib.parse
import orjson
from typing import Any, Optional, cast, Tuple, Sequence, Dict
from typing import Any, Optional, cast, Sequence, Dict
import logging
import httpx
from overrides import override
Expand Down Expand Up @@ -31,9 +31,11 @@
Where,
WhereDocument,
GetResult,
AddResult,
QueryResult,
CollectionMetadata,
validate_batch,
validate_batch_size,
RecordSet,
convert_np_embeddings_to_list,
)

Expand Down Expand Up @@ -413,16 +415,10 @@ async def _delete(

return cast(IDs, resp_json)

@trace_method("AsyncFastAPI._submit_batch", OpenTelemetryGranularity.ALL)
async def _submit_batch(
@trace_method("AsyncFastAPI._submit_record_set", OpenTelemetryGranularity.ALL)
async def _submit_record_set(
self,
batch: Tuple[
IDs,
Optional[PyEmbeddings],
Optional[Metadatas],
Optional[Documents],
Optional[URIs],
],
record_set: RecordSet,
url: str,
) -> Any:
"""
Expand All @@ -432,35 +428,55 @@ async def _submit_batch(
"post",
url,
json={
"ids": batch[0],
"embeddings": batch[1],
"metadatas": batch[2],
"documents": batch[3],
"uris": batch[4],
"ids": record_set["ids"],
"embeddings": record_set["embeddings"],
"metadatas": record_set["metadatas"],
"documents": record_set["documents"],
"uris": record_set["uris"],
},
)

@trace_method("AsyncFastAPI._add", OpenTelemetryGranularity.ALL)
@override
async def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (
ids,
convert_np_embeddings_to_list(embeddings),
metadatas,
documents,
uris,
) -> AddResult:
record_set: RecordSet = {
"ids": ids,
"embeddings": convert_np_embeddings_to_list(embeddings)
if embeddings is not None
else None,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)

resp_json = await self._make_request(
"post",
"/collections/" + str(collection_id) + "/add",
json={
"ids": record_set["ids"],
"embeddings": record_set["embeddings"],
"metadatas": record_set["metadatas"],
"documents": record_set["documents"],
"uris": record_set["uris"],
},
)

return AddResult(
ids=resp_json["ids"],
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
return True

@trace_method("AsyncFastAPI._update", OpenTelemetryGranularity.ALL)
@override
Expand All @@ -473,19 +489,23 @@ async def _update(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (
ids,
convert_np_embeddings_to_list(embeddings)
record_set: RecordSet = {
"ids": ids,
"embeddings": convert_np_embeddings_to_list(embeddings)
if embeddings is not None
else None,
metadatas,
documents,
uris,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})

await self._submit_batch(
batch, "/collections/" + str(collection_id) + "/update"
await self._submit_record_set(
record_set, "/collections/" + str(collection_id) + "/update"
)

return True
Expand All @@ -501,17 +521,25 @@ async def _upsert(
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
batch = (
ids,
convert_np_embeddings_to_list(embeddings),
metadatas,
documents,
uris,
record_set: RecordSet = {
"ids": ids,
"embeddings": convert_np_embeddings_to_list(embeddings)
if embeddings is not None
else None,
"metadatas": metadatas,
"documents": documents,
"uris": uris,
"images": None,
}

validate_batch_size(
record_set, {"max_batch_size": await self.get_max_batch_size()}
)
validate_batch(batch, {"max_batch_size": await self.get_max_batch_size()})
await self._submit_batch(
batch, "/collections/" + str(collection_id) + "/upsert"

await self._submit_record_set(
record_set, "/collections/" + str(collection_id) + "/upsert"
)

return True

@trace_method("AsyncFastAPI._query", OpenTelemetryGranularity.ALL)
Expand Down
5 changes: 3 additions & 2 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Loadable,
Metadatas,
QueryResult,
AddResult,
URIs,
)
from chromadb.config import Settings, System
Expand Down Expand Up @@ -208,13 +209,13 @@ def delete_collection(
@override
def _add(
self,
ids: IDs,
collection_id: UUID,
embeddings: Embeddings,
ids: Optional[IDs] = None,
metadatas: Optional[Metadatas] = None,
documents: Optional[Documents] = None,
uris: Optional[URIs] = None,
) -> bool:
) -> AddResult:
return self._server._add(
ids=ids,
collection_id=collection_id,
Expand Down
Loading
Loading