Add Amazon Bedrock Embedding function (#1361)
https://docs.aws.amazon.com/bedrock/latest/userguide/embeddings.html

## Description of changes

- New functionality
  - Support the Amazon Bedrock embedding function

## Test plan

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

Tested locally by passing a `profile_name` that has a matching entry in `~/.aws/config`:

```py
>>> import boto3
>>> from chromadb.utils.embedding_functions import AmazonBedrockEmbeddingFunction
>>> session = boto3.Session(profile_name="myprofile", region_name="us-east-1")
>>> ef = AmazonBedrockEmbeddingFunction(session=session)
>>> ef(["Hello Bedrock"])
[[-0.73046875, 0.390625, 0.24511719, 0.111816406, 0.83203125, 0.79296875,...,]]
```
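
The manual test above needs live AWS credentials. As a rough sketch of how the same request loop could be exercised offline, the block below uses two hypothetical stand-ins, `FakeSession` and `FakeBedrockClient`, which are not part of boto3 or chromadb. The helper `embed_texts` is also made up for illustration: it mirrors the `__call__` loop from this commit so that the example stays self-contained, and the two-element "embedding" it gets back is invented, not a real Titan vector.

```python
import io
import json
from typing import Any, Dict, List


class FakeBedrockClient:
    """Hypothetical stand-in for the boto3 "bedrock-runtime" client."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def invoke_model(self, **kwargs: Any) -> Dict[str, Any]:
        # Record the request so it can be inspected afterwards.
        self.calls.append(kwargs)
        text = json.loads(kwargs["body"])["inputText"]
        # Invented 2-element vector derived from the text length.
        payload = {"embedding": [float(len(text)), 0.0]}
        # Return a file-like body, because the commit parses it with json.load.
        return {"body": io.BytesIO(json.dumps(payload).encode("utf-8"))}


class FakeSession:
    """Hypothetical stand-in for boto3.Session."""

    def __init__(self) -> None:
        self.fake_client = FakeBedrockClient()

    def client(self, service_name: str, **kwargs: Any) -> FakeBedrockClient:
        assert service_name == "bedrock-runtime"
        return self.fake_client


def embed_texts(
    session: Any,
    texts: List[str],
    model_name: str = "amazon.titan-embed-text-v1",
) -> List[List[float]]:
    """Mirror of the commit's __call__ loop: one invoke_model request per text."""
    client = session.client(service_name="bedrock-runtime")
    embeddings = []
    for text in texts:
        response = client.invoke_model(
            body=json.dumps({"inputText": text}),
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
        )
        embeddings.append(json.load(response["body"]).get("embedding"))
    return embeddings


session = FakeSession()
vectors = embed_texts(session, ["Hello Bedrock", "Hi"])
```

Because the class in this commit only calls `session.client(...)` and `invoke_model(...)`, passing `FakeSession()` as the `session` argument of `AmazonBedrockEmbeddingFunction` should work the same way, although that has not been verified here.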

## Documentation Changes
Added docstrings wherever possible.

---------

Co-authored-by: Ben Eggers <[email protected]>
chezou and beggers authored Dec 20, 2023
1 parent a02a0d7 commit 2202df8
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions chromadb/utils/embedding_functions.py
@@ -710,6 +710,55 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",  # Quote for forward reference
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.
        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
            **kwargs: Additional arguments to pass to the boto3 client.
        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        import json

        accept = "application/json"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings


class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference).