diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index ab8d22499bc..b68d7a27d10 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -8,7 +8,6 @@ from chromadb.api.types import ( CollectionMetadata, Documents, - Embeddable, EmbeddingFunction, Embeddings, IDs, @@ -59,9 +58,7 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, ) -> Collection: """Create a new collection with the given name and metadata. @@ -93,11 +90,9 @@ def create_collection( @abstractmethod def get_collection( self, - name: str, + name: Optional[str] = None, id: Optional[UUID] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: """Get a collection with the given name. Args: @@ -124,9 +119,7 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), ) -> Collection: """Get or create a collection with the given name and metadata. Args: @@ -493,9 +486,7 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, @@ -506,11 +497,9 @@ def create_collection( @override def get_collection( self, - name: str, + name: Optional[str] = None, id: Optional[UUID] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: @@ -522,9 +511,7 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 6dcaaf84c44..38af7e52a91 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -14,7 +14,6 @@ from chromadb.api.models.Collection import Collection from chromadb.api.types import ( Documents, - Embeddable, Embeddings, EmbeddingFunction, IDs, @@ -220,9 +219,7 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), get_or_create: bool = False, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, @@ -253,9 +250,9 @@ def create_collection( @override def get_collection( self, - name: str, + name: Optional[str] = None, id: Optional[UUID] = None, - embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: @@ -287,20 +284,17 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ) -> Collection: - return cast( - Collection, - self.create_collection( - name, - metadata, - embedding_function, - get_or_create=True, - tenant=tenant, - database=database, - ), + return self.create_collection( + name, + metadata, + embedding_function, + get_or_create=True, + tenant=tenant, + database=database, ) @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION) @@ -353,13 +347,10 @@ def _peek( collection_id: UUID, n: int = 10, ) -> GetResult: - return cast( - GetResult, - self._get( - collection_id, - limit=n, - include=["embeddings", "documents", "metadatas"], - ), + return self._get( + collection_id, + limit=n, + include=["embeddings", "documents", "metadatas"], ) @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION) diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 058c9c86f8f..ef7c66139d2 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Tuple, Any +from typing import TYPE_CHECKING, Optional, Tuple, cast, List from pydantic import BaseModel, PrivateAttr from uuid import UUID @@ -7,15 +7,9 @@ from chromadb.api.types import ( CollectionMetadata, Embedding, - Embeddings, - Embeddable, Include, Metadata, - Metadatas, Document, - Documents, - Image, - Images, Where, IDs, EmbeddingFunction, @@ -24,11 +18,7 @@ ID, OneOrMany, WhereDocument, - maybe_cast_one_to_many_ids, - maybe_cast_one_to_many_embedding, - maybe_cast_one_to_many_metadata, - maybe_cast_one_to_many_document, - maybe_cast_one_to_many_image, + maybe_cast_one_to_many, validate_ids, validate_include, validate_metadata, @@ -37,7 +27,6 @@ validate_where_document, validate_n_results, validate_embeddings, - validate_embedding_function, ) import logging @@ -54,16 +43,14 @@ class Collection(BaseModel): tenant: Optional[str] = None database: Optional[str] = None _client: "ServerAPI" = PrivateAttr() - _embedding_function: Optional[EmbeddingFunction[Embeddable]] = PrivateAttr() + _embedding_function: Optional[EmbeddingFunction] = PrivateAttr() def __init__( self, client: "ServerAPI", name: str, id: UUID, - embedding_function: Optional[ - EmbeddingFunction[Embeddable] - ] = ef.DefaultEmbeddingFunction(), # type: ignore + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), tenant: Optional[str] = None, database: Optional[str] = None, metadata: Optional[CollectionMetadata] = None, @@ -72,11 +59,6 @@ def __init__( name=name, metadata=metadata, id=id, tenant=tenant, database=database ) self._client = client - - # Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol - if embedding_function is not None: - validate_embedding_function(embedding_function) - self._embedding_function = embedding_function def __repr__(self) -> str: @@ -97,15 +79,13 @@ def add( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, - images: Optional[OneOrMany[Image]] = None, ) -> None: """Add embeddings to the data store. Args: ids: The ids of the embeddings you wish to add - embeddings: The embeddings to add. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional. metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. documents: The documents to associate with the embeddings. Optional. - images: The images to associate with the embeddings. Optional. Returns: None @@ -119,22 +99,10 @@ def add( """ - ids, embeddings, metadatas, documents, images = self._validate_embedding_set( - ids, embeddings, metadatas, documents, images + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents ) - # We need to compute the embeddings if they're not provided - if embeddings is None: - # At this point, we know that one of documents or images are provided from the validation above - if documents is not None: - embeddings = self._embed(input=documents) - elif images is not None: - embeddings = self._embed(input=images) - else: - raise ValueError( - "You must provide embeddings, documents, or images, or an embedding function." - ) - self._client._add(ids, self.id, embeddings, metadatas, documents) def get( @@ -165,7 +133,7 @@ def get( where_document = ( validate_where_document(where_document) if where_document else None ) - ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None + ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None include = validate_include(include, allow_distances=False) return self._client._get( self.id, @@ -193,7 +161,6 @@ def query( self, query_embeddings: Optional[OneOrMany[Embedding]] = None, query_texts: Optional[OneOrMany[Document]] = None, - query_images: Optional[OneOrMany[Image]] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -204,7 +171,6 @@ def query( Args: query_embeddings: The embeddings to get the closes neighbors of. Optional. query_texts: The document texts to get the closes neighbors of. Optional. - query_images: The images to get the closes neighbors of. Optional. n_results: The number of neighbors to return for each query_embedding or query_texts. Optional. where: A Where type dict used to filter results by. E.g. `{"$and": ["color" : "red", "price": {"$gte": 4.20}]}`. Optional. where_document: A WhereDocument type dict used to filter by the documents. E.g. `{$contains: {"text": "hello"}}`. Optional. @@ -214,58 +180,43 @@ def query( QueryResult: A QueryResult object containing the results. Raises: - ValueError: If you don't provide either query_embeddings, query_texts, or query_images + ValueError: If you don't provide either query_embeddings or query_texts ValueError: If you provide both query_embeddings and query_texts - ValueError: If you provide both query_embeddings and query_images - ValueError: If you provide both query_texts and query_images """ - # If neither query_embeddings nor query_texts are provided, or both are provided, raise an error - if ( - (query_embeddings is None and query_texts is None and query_images is None) - or ( - query_embeddings is not None - and (query_texts is not None or query_images is not None) - ) - or (query_texts is not None and query_images is not None) - ): - raise ValueError( - "You must provide either query embeddings, or else one of query texts or query images." - ) - where = validate_where(where) if where else None where_document = ( validate_where_document(where_document) if where_document else None ) query_embeddings = ( - validate_embeddings(maybe_cast_one_to_many_embedding(query_embeddings)) + validate_embeddings(maybe_cast_one_to_many(query_embeddings)) if query_embeddings is not None else None ) query_texts = ( - maybe_cast_one_to_many_document(query_texts) - if query_texts is not None - else None - ) - query_images = ( - maybe_cast_one_to_many_image(query_images) - if query_images is not None - else None + maybe_cast_one_to_many(query_texts) if query_texts is not None else None ) include = validate_include(include, allow_distances=True) n_results = validate_n_results(n_results) - # If query_embeddings are not provided, we need to compute them from the inputs + # If neither query_embeddings nor query_texts are provided, or both are provided, raise an error + if (query_embeddings is None and query_texts is None) or ( + query_embeddings is not None and query_texts is not None + ): + raise ValueError( + "You must provide either query embeddings or query texts, but not both" + ) + + # If query_embeddings are not provided, we need to compute them from the query_texts if query_embeddings is None: - # At this point, we know that one of query_texts or query_images are provided from the validation above - if query_texts is not None: - query_embeddings = self._embed(input=query_texts) - elif query_images is not None: - query_embeddings = self._embed(input=query_images) - else: + if self._embedding_function is None: raise ValueError( - "You must provide either query embeddings, or else one of query texts or query images." + "You must provide embeddings or a function to compute them" ) + # We know query texts is not None at this point, cast for the typechecker + query_embeddings = self._embedding_function( + cast(List[Document], query_texts) + ) if where is None: where = {} @@ -309,35 +260,23 @@ def update( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, - images: Optional[OneOrMany[Image]] = None, ) -> None: """Update the embeddings, metadatas or documents for provided ids. Args: ids: The ids of the embeddings to update - embeddings: The embeddings to update. If None, embeddings will be computed based on the documents or images using the embedding_function set for the Collection. Optional. + embeddings: The embeddings to add. If None, embeddings will be computed based on the documents using the embedding_function set for the Collection. Optional. metadatas: The metadata to associate with the embeddings. When querying, you can filter on this metadata. Optional. documents: The documents to associate with the embeddings. Optional. - images: The images to associate with the embeddings. Optional. + Returns: None """ - ids, embeddings, metadatas, documents, images = self._validate_embedding_set( - ids, - embeddings, - metadatas, - documents, - images, - require_embeddings_or_data=False, + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents, require_embeddings_or_documents=False ) - if embeddings is None: - if documents is not None: - embeddings = self._embed(input=documents) - elif images is not None: - embeddings = self._embed(input=images) - self._client._update(self.id, ids, embeddings, metadatas, documents) def upsert( @@ -346,7 +285,6 @@ def upsert( embeddings: Optional[OneOrMany[Embedding]] = None, metadatas: Optional[OneOrMany[Metadata]] = None, documents: Optional[OneOrMany[Document]] = None, - images: Optional[OneOrMany[Image]] = None, ) -> None: """Update the embeddings, metadatas or documents for provided ids, or create them if they don't exist. @@ -360,16 +298,10 @@ def upsert( None """ - ids, embeddings, metadatas, documents, images = self._validate_embedding_set( - ids, embeddings, metadatas, documents, images + ids, embeddings, metadatas, documents = self._validate_embedding_set( + ids, embeddings, metadatas, documents ) - if embeddings is None: - if documents is not None: - embeddings = self._embed(input=documents) - else: - embeddings = self._embed(input=images) - self._client._upsert( collection_id=self.id, ids=ids, @@ -397,7 +329,7 @@ def delete( Raises: ValueError: If you don't provide either ids, where, or where_document """ - ids = validate_ids(maybe_cast_one_to_many_ids(ids)) if ids else None + ids = validate_ids(maybe_cast_one_to_many(ids)) if ids else None where = validate_where(where) if where else None where_document = ( validate_where_document(where_document) if where_document else None @@ -411,74 +343,58 @@ def _validate_embedding_set( embeddings: Optional[OneOrMany[Embedding]], metadatas: Optional[OneOrMany[Metadata]], documents: Optional[OneOrMany[Document]], - images: Optional[OneOrMany[Image]] = None, - require_embeddings_or_data: bool = True, + require_embeddings_or_documents: bool = True, ) -> Tuple[ IDs, - Optional[Embeddings], - Optional[Metadatas], - Optional[Documents], - Optional[Images], + List[Embedding], + Optional[List[Metadata]], + Optional[List[Document]], ]: - valid_ids = validate_ids(maybe_cast_one_to_many_ids(ids)) - valid_embeddings = ( - validate_embeddings(maybe_cast_one_to_many_embedding(embeddings)) + ids = validate_ids(maybe_cast_one_to_many(ids)) + embeddings = ( + validate_embeddings(maybe_cast_one_to_many(embeddings)) if embeddings is not None else None ) - valid_metadatas = ( - validate_metadatas(maybe_cast_one_to_many_metadata(metadatas)) + metadatas = ( + validate_metadatas(maybe_cast_one_to_many(metadatas)) if metadatas is not None else None ) - valid_documents = ( - maybe_cast_one_to_many_document(documents) - if documents is not None - else None - ) - valid_images = ( - maybe_cast_one_to_many_image(images) if images is not None else None - ) + documents = maybe_cast_one_to_many(documents) if documents is not None else None - # Check that one of embeddings or ducuments or images is provided - if require_embeddings_or_data: - if ( - valid_embeddings is None - and valid_documents is None - and valid_images is None - ): - raise ValueError("You must provide embeddings, documents, or images.") - - # Only one of documents or images can be provided - if valid_documents is not None and valid_images is not None: - raise ValueError("You can only provide documents or images, not both.") + # Check that one of embeddings or documents is provided + if require_embeddings_or_documents: + if embeddings is None and documents is None: + raise ValueError( + "You must provide either embeddings or documents, or both" + ) # Check that, if they're provided, the lengths of the arrays match the length of ids - if valid_embeddings is not None and len(valid_embeddings) != len(valid_ids): + if embeddings is not None and len(embeddings) != len(ids): raise ValueError( - f"Number of embeddings {len(valid_embeddings)} must match number of ids {len(valid_ids)}" + f"Number of embeddings {len(embeddings)} must match number of ids {len(ids)}" ) - if valid_metadatas is not None and len(valid_metadatas) != len(valid_ids): + if metadatas is not None and len(metadatas) != len(ids): raise ValueError( - f"Number of metadatas {len(valid_metadatas)} must match number of ids {len(valid_ids)}" + f"Number of metadatas {len(metadatas)} must match number of ids {len(ids)}" ) - if valid_documents is not None and len(valid_documents) != len(valid_ids): + if documents is not None and len(documents) != len(ids): raise ValueError( - f"Number of documents {len(valid_documents)} must match number of ids {len(valid_ids)}" + f"Number of documents {len(documents)} must match number of ids {len(ids)}" ) - return ( - valid_ids, - valid_embeddings, - valid_metadatas, - valid_documents, - valid_images, - ) + # If document embeddings are not provided, we need to compute them + if embeddings is None and documents is not None: + if self._embedding_function is None: + raise ValueError( + "You must provide embeddings or a function to compute them" + ) + embeddings = self._embedding_function(documents) - def _embed(self, input: Any) -> Embeddings: - if self._embedding_function is None: - raise ValueError( - "You must provide an embedding function to compute embeddings." - "https://docs.trychroma.com/embeddings" - ) - return self._embedding_function(input=input) + # if embeddings is None: + # raise ValueError( + # "Something went wrong. Embeddings should be computed at this point" + # ) + + return ids, embeddings, metadatas, documents # type: ignore diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 84c55257dcb..017e356ffac 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,6 +1,4 @@ -from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast -from numpy.typing import NDArray -import numpy as np +from typing import Optional, Union, Sequence, TypeVar, List, Dict, Any, Tuple from typing_extensions import Literal, TypedDict, Protocol import chromadb.errors as errors from chromadb.types import ( @@ -15,97 +13,27 @@ WhereDocumentOperator, WhereDocument, ) -from inspect import signature # Re-export types from chromadb.types __all__ = ["Metadata", "Where", "WhereDocument", "UpdateCollectionMetadata"] -T = TypeVar("T") -OneOrMany = Union[T, List[T]] - -# IDs ID = str IDs = List[ID] - -def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: - if isinstance(target, str): - # One ID - return cast(IDs, [target]) - # Already a sequence - return cast(IDs, target) - - -# Embeddings Embedding = Vector Embeddings = List[Embedding] - -def maybe_cast_one_to_many_embedding(target: OneOrMany[Embedding]) -> Embeddings: - if isinstance(target, List): - # One Embedding - if isinstance(target[0], (int, float)): - return cast(Embeddings, [target]) - # Already a sequence - return cast(Embeddings, target) - - -# Metadatas Metadatas = List[Metadata] - -def maybe_cast_one_to_many_metadata(target: OneOrMany[Metadata]) -> Metadatas: - # One Metadata dict - if isinstance(target, dict): - return cast(Metadatas, [target]) - # Already a sequence - return cast(Metadatas, target) - - CollectionMetadata = Dict[str, Any] UpdateCollectionMetadata = UpdateMetadata -# Documents Document = str Documents = List[Document] - -def is_document(target: Any) -> bool: - if not isinstance(target, str): - return False - return True - - -def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: - # One Document - if is_document(target): - return cast(Documents, [target]) - # Already a sequence - return cast(Documents, target) - - -# Images -ImageDType = Union[np.uint, np.int_, np.float_] -Image = NDArray[ImageDType] -Images = List[Image] - - -def is_image(target: Any) -> bool: - if not isinstance(target, np.ndarray): - return False - if len(target.shape) < 2: - return False - return True - - -def maybe_cast_one_to_many_image(target: OneOrMany[Image]) -> Images: - if is_image(target): - return cast(Images, [target]) - # Already a sequence - return cast(Images, target) - - -Parameter = TypeVar("Parameter", Document, Image, Embedding, Metadata, ID) +Parameter = TypeVar("Parameter", Embedding, Document, Metadata, ID) +T = TypeVar("T") +OneOrMany = Union[T, List[T]] # This should ust be List[Literal["documents", "embeddings", "metadatas", "distances"]] # However, this provokes an incompatibility with the Overrides library and Python 3.7 @@ -153,29 +81,28 @@ class IndexMetadata(TypedDict): time_created: float -Embeddable = Union[Documents, Images] -D = TypeVar("D", bound=Embeddable, contravariant=True) - - -class EmbeddingFunction(Protocol[D]): - def __call__(self, input: D) -> Embeddings: +class EmbeddingFunction(Protocol): + def __call__(self, texts: Documents) -> Embeddings: ... -def validate_embedding_function( - embedding_function: EmbeddingFunction[Embeddable], -) -> None: - function_signature = signature( - embedding_function.__class__.__call__ - ).parameters.keys() - protocol_signature = signature(EmbeddingFunction.__call__).parameters.keys() +def maybe_cast_one_to_many( + target: OneOrMany[Parameter], +) -> List[Parameter]: + """Infers if target is Embedding, Metadata, or Document and casts it to a many object if its one""" - if not function_signature == protocol_signature: - raise ValueError( - f"Expected EmbeddingFunction.__call__ to have the following signature: {protocol_signature}, got {function_signature}\n" - "Please see https://docs.trychroma.com/embeddings for details of the EmbeddingFunction interface.\n" - "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n" - ) + if isinstance(target, Sequence): + # One Document or ID + if isinstance(target, str) and target is not None: + return [target] + # One Embedding + if isinstance(target[0], (int, float)): + return [target] # type: ignore + # One Metadata dict + if isinstance(target, dict): + return [target] + # Already a sequence + return target # type: ignore def validate_ids(ids: IDs) -> IDs: diff --git a/chromadb/test/conftest.py b/chromadb/test/conftest.py index 087cb2271bd..401139684ab 100644 --- a/chromadb/test/conftest.py +++ b/chromadb/test/conftest.py @@ -429,7 +429,6 @@ def client(system: System) -> Generator[ClientAPI, None, None]: system.reset_state() client = ClientCreator.from_system(system) yield client - client.clear_system_cache() @pytest.fixture(scope="function") diff --git a/chromadb/test/ef/test_multimodal_ef.py b/chromadb/test/ef/test_multimodal_ef.py deleted file mode 100644 index 52213c77a4c..00000000000 --- a/chromadb/test/ef/test_multimodal_ef.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Generator, cast -import numpy as np -import pytest -import chromadb -from chromadb.api.types import ( - Embeddable, - EmbeddingFunction, - Embeddings, - Image, - Document, -) -from chromadb.test.property.strategies import hashing_embedding_function -from chromadb.test.property.invariants import _exact_distances - - -# A 'standard' multimodal embedding function, which converts inputs to strings -# then hashes them to a fixed dimension. -class hashing_multimodal_ef(EmbeddingFunction[Embeddable]): - def __init__(self) -> None: - self._hef = hashing_embedding_function(dim=10, dtype=np.float_) - - def __call__(self, input: Embeddable) -> Embeddings: - to_texts = [str(i) for i in input] - embeddings = np.array(self._hef(to_texts)) - # Normalize the embeddings - # This is so we can generate random unit vectors and have them be close to the embeddings - embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True) - return cast(Embeddings, embeddings.tolist()) - - -def random_image() -> Image: - return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int32) - - -def random_document() -> Document: - return str(random_image()) - - -@pytest.fixture -def multimodal_collection( - default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(), -) -> Generator[chromadb.Collection, None, None]: - client = chromadb.Client() - collection = client.create_collection( - name="multimodal_collection", embedding_function=default_ef - ) - yield collection - client.clear_system_cache() - - -# Test adding and querying of a multimodal collection consisting of images and documents -def test_multimodal( - multimodal_collection: chromadb.Collection, - default_ef: EmbeddingFunction[Embeddable] = hashing_multimodal_ef(), - n_examples: int = 10, - n_query_results: int = 3, -) -> None: - image_ids = [str(i) for i in range(n_examples)] - images = [random_image() for _ in range(n_examples)] - image_embeddings = default_ef(images) - - document_ids = [str(i) for i in range(n_examples, 2 * n_examples)] - documents = [random_document() for _ in range(n_examples)] - document_embeddings = default_ef(documents) - - # Trying to add a document and an image at the same time should fail - with pytest.raises( - ValueError, match="You can only provide documents or images, not both." - ): - multimodal_collection.add( - ids=image_ids[0], documents=documents[0], images=images[0] - ) - - # Add some documents - multimodal_collection.add(ids=document_ids, documents=documents) - # Add some images - multimodal_collection.add(ids=image_ids, images=images) - - # get() should return all the documents and images - # ids corresponding to images should not have documents - get_result = multimodal_collection.get(include=["documents"]) - assert len(get_result["ids"]) == len(document_ids) + len(image_ids) - for i, id in enumerate(get_result["ids"]): - assert id in document_ids or id in image_ids - assert get_result["documents"] is not None - if id in document_ids: - assert get_result["documents"][i] == documents[document_ids.index(id)] - if id in image_ids: - assert get_result["documents"][i] is None - - # Generate a random query image - query_image = random_image() - query_image_embedding = default_ef([query_image]) - - image_neighbor_indices, _ = _exact_distances( - query_image_embedding, image_embeddings + document_embeddings - ) - # Get the ids of the nearest neighbors - nearest_image_neighbor_ids = [ - image_ids[i] if i < n_examples else document_ids[i % n_examples] - for i in image_neighbor_indices[0][:n_query_results] - ] - - # Generate a random query document - query_document = random_document() - query_document_embedding = default_ef([query_document]) - document_neighbor_indices, _ = _exact_distances( - query_document_embedding, image_embeddings + document_embeddings - ) - nearest_document_neighbor_ids = [ - image_ids[i] if i < n_examples else document_ids[i % n_examples] - for i in document_neighbor_indices[0][:n_query_results] - ] - - # Querying with both images and documents should fail - with pytest.raises(ValueError): - multimodal_collection.query( - query_images=[query_image], query_texts=[query_document] - ) - - # Query with images - query_result = multimodal_collection.query( - query_images=[query_image], n_results=n_query_results, include=["documents"] - ) - - assert query_result["ids"][0] == nearest_image_neighbor_ids - - # Query with documents - query_result = multimodal_collection.query( - query_texts=[query_document], n_results=n_query_results, include=["documents"] - ) - - assert query_result["ids"][0] == nearest_document_neighbor_ids - - -@pytest.mark.xfail -def test_multimodal_update_with_image( - multimodal_collection: chromadb.Collection, -) -> None: - # Updating an entry with an existing document should remove the documentß - - document = random_document() - image = random_image() - id = "0" - - multimodal_collection.add(ids=id, documents=document) - - multimodal_collection.update(ids=id, images=image) - - get_result = multimodal_collection.get(ids=id, include=["documents"]) - assert get_result["documents"] is not None - assert get_result["documents"][0] is None diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 142fbc8b3f2..3583dadfba9 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -1,7 +1,7 @@ import hashlib import hypothesis import hypothesis.strategies as st -from typing import Any, Optional, List, Dict, Union, cast +from typing import Any, Optional, List, Dict, Union from typing_extensions import TypedDict import numpy as np import numpy.typing as npt @@ -13,14 +13,8 @@ from dataclasses import dataclass -from chromadb.api.types import ( - Documents, - Embeddable, - EmbeddingFunction, - Embeddings, - Metadata, -) -from chromadb.types import LiteralValue, WhereOperator, LogicalOperator +from chromadb.api.types import Documents, Embeddings, Metadata +from chromadb.types import LiteralValue # Set the random seed for reproducibility np.random.seed(0) # unnecessary, hypothesis does this for us @@ -184,15 +178,15 @@ def create_embeddings( return embeddings -class hashing_embedding_function(types.EmbeddingFunction[Documents]): +class hashing_embedding_function(types.EmbeddingFunction): def __init__(self, dim: int, dtype: npt.DTypeLike) -> None: self.dim = dim self.dtype = dtype - def __call__(self, input: types.Documents) -> types.Embeddings: + def __call__(self, texts: types.Documents) -> types.Embeddings: # Hash the texts and convert to hex strings hashed_texts = [ - list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in input + list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in texts ] # Pad with repetition, or truncate the hex strings to the desired dimension padded_texts = [ @@ -209,17 +203,15 @@ def __call__(self, input: types.Documents) -> types.Embeddings: return embeddings -class not_implemented_embedding_function(types.EmbeddingFunction[Documents]): - def __call__(self, input: Documents) -> Embeddings: +class not_implemented_embedding_function(types.EmbeddingFunction): + def __call__(self, texts: Documents) -> Embeddings: assert False, "This embedding function is not implemented" def embedding_function_strategy( dim: int, dtype: npt.DTypeLike -) -> st.SearchStrategy[types.EmbeddingFunction[Embeddable]]: - return st.just( - cast(EmbeddingFunction[Embeddable], hashing_embedding_function(dim, dtype)) - ) +) -> st.SearchStrategy[types.EmbeddingFunction]: + return st.just(hashing_embedding_function(dim, dtype)) @dataclass @@ -232,7 +224,7 @@ class Collection: known_document_keywords: List[str] has_documents: bool = False has_embeddings: bool = False - embedding_function: Optional[types.EmbeddingFunction[Embeddable]] = None + embedding_function: Optional[types.EmbeddingFunction] = None @st.composite @@ -319,12 +311,12 @@ def metadata(draw: st.DrawFn, collection: Collection) -> types.Metadata: if collection.known_metadata_keys: for key in collection.known_metadata_keys.keys(): if key in metadata: - del metadata[key] # type: ignore + del metadata[key] # Finally, add in some of the known keys for the collection sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = { k: st.just(v) for k, v in collection.known_metadata_keys.items() } - metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore + metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) return metadata @@ -340,11 +332,11 @@ def document(draw: st.DrawFn, collection: Collection) -> types.Document: else: known_words_st = st.text( min_size=1, - alphabet=st.characters(blacklist_categories=blacklist_categories), # type: ignore + alphabet=st.characters(blacklist_categories=blacklist_categories), ) random_words_st = st.text( - min_size=1, alphabet=st.characters(blacklist_categories=blacklist_categories) # type: ignore + min_size=1, alphabet=st.characters(blacklist_categories=blacklist_categories) ) words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1)) return " ".join(words) @@ -495,20 +487,20 @@ def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where: # Add or subtract a small number to avoid floating point rounding errors value = value + draw(st.sampled_from([1e-6, -1e-6])) - op: WhereOperator = draw(st.sampled_from(legal_ops)) + op: types.WhereOperator = draw(st.sampled_from(legal_ops)) if op is None: return {key: value} - elif op == "$in": # type: ignore + elif op == "$in": if isinstance(value, str) and not value: return {} return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}} - elif op == "$nin": # type: ignore + elif op == "$nin": if isinstance(value, str) and not value: return {} return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}} else: - return {key: {op: value}} # type: ignore + return {key: {op: value}} @st.composite @@ -524,7 +516,7 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu def binary_operator_clause( base_st: SearchStrategy[types.Where], ) -> SearchStrategy[types.Where]: - op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"]) + op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"]) return st.dictionaries( keys=op, values=st.lists(base_st, max_size=2, min_size=2), @@ -536,7 +528,7 @@ def binary_operator_clause( def binary_document_operator_clause( base_st: SearchStrategy[types.WhereDocument], ) -> SearchStrategy[types.WhereDocument]: - op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"]) + op: SearchStrategy[types.LogicalOperator] = st.sampled_from(["$and", "$or"]) return st.dictionaries( keys=op, values=st.lists(base_st, max_size=2, min_size=2), diff --git a/chromadb/test/property/test_add.py b/chromadb/test/property/test_add.py index f97e33aa305..5f8991b00ed 100644 --- a/chromadb/test/property/test_add.py +++ b/chromadb/test/property/test_add.py @@ -26,7 +26,7 @@ def test_add( # TODO: Generative embedding functions coll = api.create_collection( name=collection.name, - metadata=collection.metadata, # type: ignore + metadata=collection.metadata, embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -64,7 +64,7 @@ def create_large_recordset( "metadatas": metadatas, "documents": documents, } - return cast(strategies.RecordSet, record_set) + return record_set @given(collection=collection_st) @@ -77,7 +77,7 @@ def test_add_large(api: ServerAPI, collection: strategies.Collection) -> None: ) coll = api.create_collection( name=collection.name, - metadata=collection.metadata, # type: ignore + metadata=collection.metadata, embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -107,7 +107,7 @@ def test_add_large_exceeding(api: ServerAPI, collection: strategies.Collection) ) coll = api.create_collection( name=collection.name, - metadata=collection.metadata, # type: ignore + metadata=collection.metadata, embedding_function=collection.embedding_function, ) normalized_record_set = invariants.wrap_all(record_set) @@ -157,7 +157,7 @@ def test_out_of_order_ids(api: ServerAPI) -> None: ] coll = api.create_collection( - "test", embedding_function=lambda input: [[1, 2, 3] for _ in input] # type: ignore + "test", embedding_function=lambda texts: [[1, 2, 3] for _ in texts] # type: ignore ) embeddings: Embeddings = [[1, 2, 3] for _ in ooo_ids] coll.add(ids=ooo_ids, embeddings=embeddings) @@ -174,7 +174,7 @@ def test_add_partial(api: ServerAPI) -> None: # TODO: We need to clean up the api types to support this typing coll.add( ids=["1", "2", "3"], - embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], # type: ignore + embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]], metadatas=[{"a": 1}, None, {"a": 3}], # type: ignore documents=["a", "b", None], # type: ignore ) diff --git a/chromadb/test/property/test_cross_version_persist.py b/chromadb/test/property/test_cross_version_persist.py index 82bfc5f7cda..b5320dfe7a9 100644 --- a/chromadb/test/property/test_cross_version_persist.py +++ b/chromadb/test/property/test_cross_version_persist.py @@ -204,8 +204,8 @@ def switch_to_version(version: str) -> ModuleType: return chromadb -class not_implemented_ef(EmbeddingFunction[Documents]): - def __call__(self, input: Documents) -> Embeddings: +class not_implemented_ef(EmbeddingFunction): + def __call__(self, texts: Documents) -> Embeddings: assert False, "Embedding function should not be called" @@ -314,7 +314,7 @@ def test_cycle_versions( system.start() coll = api.get_collection( name=collection_strategy.name, - embedding_function=not_implemented_ef(), # type: ignore + embedding_function=not_implemented_ef(), ) invariants.count(coll, embeddings_strategy) invariants.metadatas_match(coll, embeddings_strategy) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index d6d3e3c30a8..ed3c87ee682 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -3,7 +3,7 @@ import chromadb from chromadb.api.fastapi import FastAPI -from chromadb.api.types import QueryResult, EmbeddingFunction, Document +from chromadb.api.types import QueryResult from chromadb.config import Settings import chromadb.server.fastapi import pytest @@ -91,17 +91,14 @@ def test_persist_index_loading(api_fixture, request): @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_loading_embedding_function(api_fixture, request): - class TestEF(EmbeddingFunction[Document]): - def __call__(self, input): - return [[1, 2, 3] for _ in range(len(input))] - + embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731 api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.create_collection("test", embedding_function=TestEF()) + collection = api.create_collection("test", embedding_function=embedding_function) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") - collection = api2.get_collection("test", embedding_function=TestEF()) + collection = api2.get_collection("test", embedding_function=embedding_function) nn = collection.query( query_texts="hello", @@ -114,17 +111,18 @@ def __call__(self, input): @pytest.mark.parametrize("api_fixture", [local_persist_api]) def test_persist_index_get_or_create_embedding_function(api_fixture, request): - class TestEF(EmbeddingFunction[Document]): - def __call__(self, input): - return [[1, 2, 3] for _ in range(len(input))] - + embedding_function = lambda x: [[1, 2, 3] for _ in range(len(x))] # noqa E731 api = request.getfixturevalue("local_persist_api") api.reset() - collection = api.get_or_create_collection("test", embedding_function=TestEF()) + collection = api.get_or_create_collection( + "test", embedding_function=embedding_function + ) collection.add(ids="id1", documents="hello") api2 = request.getfixturevalue("local_persist_api_cache_bust") - collection = api2.get_or_create_collection("test", embedding_function=TestEF()) + collection = api2.get_or_create_collection( + "test", embedding_function=embedding_function + ) nn = collection.query( query_texts="hello", diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 5e38936ef6c..aaef53c01e2 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -1,21 +1,11 @@ import logging -from chromadb.api.types import ( - Document, - Documents, - Embedding, - Image, - Images, - EmbeddingFunction, - Embeddings, - is_image, - is_document, -) +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from pathlib import Path import os import tarfile import requests -from typing import Any, Dict, List, Union, cast +from typing import Any, Dict, List, cast import numpy as np import numpy.typing as npt import importlib @@ -31,7 +21,7 @@ logger = logging.getLogger(__name__) -class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): +class SentenceTransformerEmbeddingFunction(EmbeddingFunction): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} @@ -54,15 +44,15 @@ def __init__( self._model = self.models[model_name] self._normalize_embeddings = normalize_embeddings - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: return self._model.encode( # type: ignore - list(input), + list(texts), convert_to_numpy=True, normalize_embeddings=self._normalize_embeddings, ).tolist() -class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): +class Text2VecEmbeddingFunction(EmbeddingFunction): def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): try: from text2vec import SentenceModel @@ -72,11 +62,11 @@ def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): ) self._model = SentenceModel(model_name_or_path=model_name) - def __call__(self, input: Documents) -> Embeddings: - return self._model.encode(list(input), convert_to_numpy=True).tolist() # type: ignore # noqa E501 + def __call__(self, texts: Documents) -> Embeddings: + return self._model.encode(list(texts), convert_to_numpy=True).tolist() # type: ignore # noqa E501 -class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): +class OpenAIEmbeddingFunction(EmbeddingFunction): def __init__( self, api_key: Optional[str] = None, @@ -135,12 +125,12 @@ def __init__( self._client = openai.Embedding self._model_name = model_name - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: # replace newlines, which can negatively affect performance. - input = [t.replace("\n", " ") for t in input] + texts = [t.replace("\n", " ") for t in texts] # Call the OpenAI Embedding API - embeddings = self._client.create(input=input, engine=self._model_name)["data"] + embeddings = self._client.create(input=texts, engine=self._model_name)["data"] # Sort resulting embeddings by index sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore @@ -149,7 +139,7 @@ def __call__(self, input: Documents) -> Embeddings: return [result["embedding"] for result in sorted_embeddings] -class CohereEmbeddingFunction(EmbeddingFunction[Documents]): +class CohereEmbeddingFunction(EmbeddingFunction): def __init__(self, api_key: str, model_name: str = "large"): try: import cohere @@ -161,15 +151,15 @@ def __init__(self, api_key: str, model_name: str = "large"): self._client = cohere.Client(api_key) self._model_name = model_name - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: # Call Cohere Embedding API for each document. return [ embeddings - for embeddings in self._client.embed(texts=input, model=self._model_name) + for embeddings in self._client.embed(texts=texts, model=self._model_name) ] -class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): +class HuggingFaceEmbeddingFunction(EmbeddingFunction): """ This class is used to get embeddings for a list of texts using the HuggingFace API. It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". @@ -195,7 +185,7 @@ def __init__( self._session = requests.Session() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: """ Get the embeddings for a list of texts. @@ -212,11 +202,11 @@ def __call__(self, input: Documents) -> Embeddings: """ # Call HuggingFace Embedding API for each document return self._session.post( # type: ignore - self._api_url, json={"inputs": input, "options": {"wait_for_model": True}} + self._api_url, json={"inputs": texts, "options": {"wait_for_model": True}} ).json() -class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): +class InstructorEmbeddingFunction(EmbeddingFunction): # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list def __init__( @@ -234,11 +224,11 @@ def __init__( self._model = INSTRUCTOR(model_name, device=device) self._instruction = instruction - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: if self._instruction is None: - return self._model.encode(input).tolist() # type: ignore + return self._model.encode(texts).tolist() # type: ignore - texts_with_instructions = [[self._instruction, text] for text in input] + texts_with_instructions = [[self._instruction, text] for text in texts] return self._model.encode(texts_with_instructions).tolist() # type: ignore @@ -247,7 +237,7 @@ def __call__(self, input: Documents) -> Embeddings: # implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. # visit https://github.com/chroma-core/onnx-embedding for the source code to generate # and verify the ONNX model. -class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): +class ONNXMiniLM_L6_V2(EmbeddingFunction): MODEL_NAME = "all-MiniLM-L6-v2" DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME EXTRACTED_FOLDER_NAME = "onnx" @@ -384,11 +374,11 @@ def _init_model_and_tokenizer(self) -> None: providers=self._preferred_providers, ) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: # Only download the model when it is actually used self._download_model_if_not_exists() self._init_model_and_tokenizer() - res = cast(Embeddings, self._forward(input).tolist()) + res = cast(Embeddings, self._forward(texts).tolist()) return res def _download_model_if_not_exists(self) -> None: @@ -423,14 +413,14 @@ def _download_model_if_not_exists(self) -> None: tar.extractall(path=self.DOWNLOAD_PATH) -def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: +def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction]: if is_thin_client: return None else: return ONNXMiniLM_L6_V2() -class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): +class GooglePalmEmbeddingFunction(EmbeddingFunction): """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key.""" def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"): @@ -451,16 +441,16 @@ def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001") self._palm = palm self._model_name = model_name - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: return [ self._palm.generate_embeddings(model=self._model_name, text=text)[ "embedding" ] - for text in input + for text in texts ] -class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): +class GoogleVertexEmbeddingFunction(EmbeddingFunction): # Follow API Quickstart for Google Vertex AI # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart # Information about the text embedding modules in Google Vertex AI @@ -476,9 +466,9 @@ def __init__( self._session = requests.Session() self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - def __call__(self, input: Documents) -> Embeddings: + def __call__(self, texts: Documents) -> Embeddings: embeddings = [] - for text in input: + for text in texts: response = self._session.post( self._api_url, json={"instances": [{"content": text}]} ).json() @@ -489,62 +479,6 @@ def __call__(self, input: Documents) -> Embeddings: return embeddings -class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, model_name: str = "ViT-B-32", checkpoint: str = "laion2b_s34b_b79k" - ) -> None: - try: - import open_clip - except ImportError: - raise ValueError( - "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" - ) - try: - self._torch = importlib.import_module("torch") - except ImportError: - raise ValueError( - "The torch python package is not installed. Please install it with `pip install torch`" - ) - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: - raise ValueError( - "The PIL python package is not installed. Please install it with `pip install pillow`" - ) - - model, _, preprocess = open_clip.create_model_and_transforms( - model_name=model_name, pretrained=checkpoint - ) - self._model = model - self._preprocess = preprocess - self._tokenizer = open_clip.get_tokenizer(model_name=model_name) - - def _encode_image(self, image: Image) -> Embedding: - pil_image = self._PILImage.fromarray(image) - with self._torch.no_grad(): - image_features = self._model.encode_image( - self._preprocess(pil_image).unsqueeze(0) - ) - image_features /= image_features.norm(dim=-1, keepdim=True) - return cast(Embedding, image_features.squeeze().tolist()) - - def _encode_text(self, text: Document) -> Embedding: - with self._torch.no_grad(): - text_features = self._model.encode_text(self._tokenizer(text)) - text_features /= text_features.norm(dim=-1, keepdim=True) - return cast(Embedding, text_features.squeeze().tolist()) - - def __call__(self, input: Union[Documents, Images]) -> Embeddings: - embeddings: Embeddings = [] - for item in input: - if is_image(item): - embeddings.append(self._encode_image(cast(Image, item))) - elif is_document(item): - embeddings.append(self._encode_text(cast(Document, item))) - return embeddings - - # List of all classes in this module _classes = [ name diff --git a/multimodal_ef_example.ipynb b/multimodal_ef_example.ipynb deleted file mode 100644 index 879c04454a5..00000000000 --- a/multimodal_ef_example.ipynb +++ /dev/null @@ -1,102 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import chromadb\n", - "\n", - "client = chromadb.Client()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from chromadb.api.types import Embeddings, Images\n", - "from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction\n", - "\n", - "embedding_function = OpenCLIPEmbeddingFunction()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "collection = client.create_collection('test', embedding_function=embedding_function)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "from PIL import Image\n", - "\n", - "image = np.array(Image.open('test_img.jpeg'))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "collection.add(ids='a', images=image)\n", - "collection.add(ids='b', documents='hello world')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'ids': ['a', 'b'],\n", - " 'embeddings': None,\n", - " 'metadatas': None,\n", - " 'documents': [None, 'hello world']}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "collection.get(include=['documents'])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "chroma", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}