Skip to content

Commit

Permalink
feat: add Jina AI embedding function (#1324)
Browse files Browse the repository at this point in the history
## Description of changes

Hey Chroma team!

We just launched [Jina Embeddings](https://jina.ai/embeddings/) and
would love to make it possible for the community to use it through a
`JinaEmbeddingFunction`.

Thanks!

## Documentation Changes
Link to docs PR: chroma-core/docs#153

---------

Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
JoanFM authored Nov 17, 2023
1 parent c9b9521 commit 35c6f86
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,64 @@ def __call__(self, input: Documents) -> Embeddings:
).json()


class JinaEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the Jina AI API.
    It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en".
    """

    def __init__(
        self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"
    ):
        """
        Initialize the JinaEmbeddingFunction.

        Args:
            api_key (str): Your API key for the Jina AI API.
            model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en".

        Raises:
            ValueError: If the `requests` package is not installed.
        """
        # Imported lazily so `requests` is only needed when this class is used.
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. Please install it with `pip install requests`"
            )
        self._model_name = model_name
        self._api_url = 'https://api.jina.ai/v1/embeddings'
        # A single session carries the auth headers for every request made by __call__.
        self._session = requests.Session()
        self._session.headers.update(
            {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
        )

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts, in the same order as `input`.

        Raises:
            RuntimeError: If the API response does not contain a "data" field.
                The message is the response's "detail" field when present,
                otherwise the whole response payload.

        Example:
            >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key")
            >>> input = ["Hello, world!", "How are you?"]
            >>> embeddings = jina_ai_fn(input)
        """
        # Call Jina AI Embedding API
        resp = self._session.post(  # type: ignore
            self._api_url, json={"input": input, "model": self._model_name}
        ).json()
        if "data" not in resp:
            # An error payload is not guaranteed to be a dict with a "detail" key;
            # fall back to the raw payload so the caller always gets a RuntimeError
            # describing the response instead of an unrelated KeyError/TypeError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)

        embeddings = resp["data"]

        # Sort resulting embeddings by index so they line up with the input order
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore

        # Return just the embeddings
        return [result["embedding"] for result in sorted_embeddings]


class InstructorEmbeddingFunction(EmbeddingFunction[Documents]):
# If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda"
# for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list
Expand Down

0 comments on commit 35c6f86

Please sign in to comment.