From 35c6f86ae999740c8636b49deb0388d9db34173f Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Fri, 17 Nov 2023 07:32:03 +0100 Subject: [PATCH] feat: add Jina AI embedding function (#1324) ## Description of changes Hey Chroma team! We just launched [Jina Embeddings](https://jina.ai/embeddings/) and would love to add a possibility for the community to use it with JinaEmbeddingFunction. Thanks! ## Documentation Changes Link to docs PR: https://github.com/chroma-core/docs/pull/153 --------- Signed-off-by: Joan Fontanals Martinez --- chromadb/utils/embedding_functions.py | 58 +++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 21ca2eb78e2..23826ac4a6e 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -256,6 +256,64 @@ def __call__(self, input: Documents) -> Embeddings: ).json() +class JinaEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the Jina AI API. + It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". + """ + + def __init__( + self, api_key: str, model_name: str = "jina-embeddings-v2-base-en" + ): + """ + Initialize the JinaEmbeddingFunction. + + Args: + api_key (str): Your API key for the Jina AI API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en". + """ + try: + import requests + except ImportError: + raise ValueError( + "The requests python package is not installed. 
Please install it with `pip install requests`" + ) + self._model_name = model_name + self._api_url = 'https://api.jina.ai/v1/embeddings' + self._session = requests.Session() + self._session.headers.update({"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key") + >>> input = ["Hello, world!", "How are you?"] + >>> embeddings = jina_ai_fn(input) + """ + # Call Jina AI Embedding API + resp = self._session.post( # type: ignore + self._api_url, json={"input": input, "model": self._model_name} + ).json() + if "data" not in resp: + raise RuntimeError(resp["detail"]) + + embeddings = resp["data"] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore + + # Return just the embeddings + return [result["embedding"] for result in sorted_embeddings] + + class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list