diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index 21ca2eb78e2..23826ac4a6e 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -256,6 +256,64 @@ def __call__(self, input: Documents) -> Embeddings:
         ).json()
 
 
+class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
+    """
+    This class is used to get embeddings for a list of texts using the Jina AI API.
+    It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
+    """
+
+    def __init__(
+        self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"
+    ):
+        """
+        Initialize the JinaEmbeddingFunction.
+
+        Args:
+            api_key (str): Your API key for the Jina AI API.
+            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".
+        """
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. Please install it with `pip install requests`"
+            )
+        self._model_name = model_name
+        self._api_url = "https://api.jina.ai/v1/embeddings"
+        self._session = requests.Session()
+        self._session.headers.update({"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"})
+
+    def __call__(self, input: Documents) -> Embeddings:
+        """
+        Get the embeddings for a list of texts.
+
+        Args:
+            input (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embeddings: The embeddings for the texts.
+
+        Example:
+            >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key")
+            >>> input = ["Hello, world!", "How are you?"]
+            >>> embeddings = jina_ai_fn(input)
+        """
+        # Call Jina AI Embedding API
+        resp = self._session.post(  # type: ignore
+            self._api_url, json={"input": input, "model": self._model_name}
+        ).json()
+        if "data" not in resp:
+            raise RuntimeError(resp["detail"])
+
+        embeddings = resp["data"]
+
+        # Sort resulting embeddings by index
+        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+
+        # Return just the embeddings
+        return [result["embedding"] for result in sorted_embeddings]
+
+
 class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
     # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
     # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
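
For reviewers, a short usage sketch showing how the new JinaEmbeddingFunction plugs into a Chroma collection. This is not part of the diff: it assumes the patch above has been applied, and the API key, collection name, ids, and sample documents are placeholders.

# Minimal usage sketch (assumes the patch above is applied; "your_api_key",
# the collection name, ids, and documents below are illustrative placeholders).
import chromadb
from chromadb.utils import embedding_functions

jina_ef = embedding_functions.JinaEmbeddingFunction(
    api_key="your_api_key",
    model_name="jina-embeddings-v2-base-en",
)

client = chromadb.Client()
collection = client.create_collection(name="jina_demo", embedding_function=jina_ef)

# Documents are embedded via the Jina AI API on insert and again on query.
collection.add(
    documents=["Hello, world!", "How are you?"],
    ids=["doc1", "doc2"],
)
results = collection.query(query_texts=["a friendly greeting"], n_results=1)
print(results["ids"])

# The embedding function can also be called directly on a list of texts:
embeddings = jina_ef(["Hello, world!"])
print(len(embeddings[0]))  # dimensionality of the returned vector
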