Skip to content

Commit

Permalink
[ENH]: Query by collection id (#1310)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
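	 - Added lookup by collection id to `BaseAPI`: `get_collection` now accepts an optional `id` as an alternative to `name` (exactly one of the two must be given)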

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for Python

## Documentation Changes
TBD

---------

Co-authored-by: hammadb <[email protected]>
  • Loading branch information
tazarov and HammadB authored Nov 1, 2023
1 parent dbf104f commit 4ec8d7a
Show file tree
Hide file tree
Showing 16 changed files with 206 additions and 57 deletions.
6 changes: 4 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def create_collection(
@abstractmethod
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
"""Get a collection with the given name.
Expand Down Expand Up @@ -496,7 +497,8 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand Down
4 changes: 3 additions & 1 deletion chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
return self._server.get_collection(
id=id,
name=name,
embedding_function=embedding_function,
tenant=self.tenant,
Expand Down
12 changes: 9 additions & 3 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,21 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Returns a collection"""
if (name is None and id is None) or (name is not None and id is not None):
raise ValueError("Name or id must be specified, but not both")

_params = {"tenant": tenant, "database": database}
if id is not None:
_params["type"] = str(id)
resp = self._session.get(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
self._api_url + "/collections/" + name if name else str(id), params=_params
)
raise_chroma_error(resp)
resp_json = resp.json()
Expand Down
8 changes: 7 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Collection(BaseModel):
name: str
id: UUID
metadata: Optional[CollectionMetadata] = None
tenant: Optional[str] = None
database: Optional[str] = None
_client: "ServerAPI" = PrivateAttr()
_embedding_function: Optional[EmbeddingFunction] = PrivateAttr()

Expand All @@ -49,9 +51,13 @@ def __init__(
name: str,
id: UUID,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: Optional[str] = None,
database: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
):
super().__init__(name=name, metadata=metadata, id=id)
super().__init__(
name=name, metadata=metadata, id=id, tenant=tenant, database=database
)
self._client = client
self._embedding_function = embedding_function

Expand Down
17 changes: 14 additions & 3 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def create_collection(
name=name,
metadata=coll["metadata"], # type: ignore
embedding_function=embedding_function,
tenant=tenant,
database=database,
)

@trace_method(
Expand Down Expand Up @@ -214,13 +216,16 @@ def get_or_create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
if id is None and name is None or (id is not None and name is not None):
raise ValueError("Name or id must be specified, but not both")
existing = self._sysdb.get_collections(
name=name, tenant=tenant, database=database
id=id, name=name, tenant=tenant, database=database
)

if existing:
Expand All @@ -230,6 +235,8 @@ def get_collection(
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
tenant=tenant,
database=database,
)
else:
raise ValueError(f"Collection {name} does not exist.")
Expand All @@ -250,6 +257,8 @@ def list_collections(
id=db_collection["id"],
name=db_collection["name"],
metadata=db_collection["metadata"], # type: ignore
tenant=db_collection["tenant"],
database=db_collection["database"],
)
)
return collections
Expand Down Expand Up @@ -486,7 +495,9 @@ def _get(
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore
metadatas=_clean_metadatas(metadatas)
if "metadatas" in include
else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
)

Expand Down
2 changes: 2 additions & 0 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def CreateCollection(
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
database=database,
tenant=tenant,
)
collections[request.id] = new_collection
return CreateCollectionResponse(
Expand Down
15 changes: 14 additions & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ def create_collection(

topic = self._assignment_policy.assign_collection(id)
collection = Collection(
id=id, topic=topic, name=name, metadata=metadata, dimension=dimension
id=id,
topic=topic,
name=name,
metadata=metadata,
dimension=dimension,
tenant=tenant,
database=database,
)

with self.tx() as cur:
Expand Down Expand Up @@ -379,6 +385,7 @@ def get_collections(

collections_t = Table("collections")
metadata_t = Table("collection_metadata")
databases_t = Table("databases")
q = (
self.querybuilder()
.from_(collections_t)
Expand All @@ -387,13 +394,17 @@ def get_collections(
collections_t.name,
collections_t.topic,
collections_t.dimension,
databases_t.name,
databases_t.tenant_id,
metadata_t.key,
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
)
.left_join(metadata_t)
.on(collections_t.id == metadata_t.collection_id)
.left_join(databases_t)
.on(collections_t.database_id == databases_t.id)
.orderby(collections_t.id)
)
if id:
Expand Down Expand Up @@ -433,6 +444,8 @@ def get_collections(
name=name,
metadata=metadata,
dimension=dimension,
tenant=str(rows[0][5]),
database=str(rows[0][4]),
)
)

Expand Down
Loading

0 comments on commit 4ec8d7a

Please sign in to comment.