Skip to content

Commit

Permalink
support of OpenAI package v1.X.X for utils.OpenAIEmbeddingFunction, d…
Browse files Browse the repository at this point in the history
…eployment_id parameter for openai v0.X.X (#1338)

## Description of changes
- Add support of OpenAI package v1.X.X for utils.OpenAIEmbeddingFunction
- Add Azure OpenAI Deployment ID parameter for openai v0.X.X lib in
utils.OpenAIEmbeddingFunction

## Test plan
*How are these changes tested?*

Tested as dependency of https://github.com/Nayjest/ai-microcore with
Azure & openai packages v0.28.1 & v1.0.1, v1.1.0
  • Loading branch information
Nayjest authored Nov 8, 2023
1 parent 93993bb commit a14f158
Showing 1 changed file with 44 additions and 8 deletions.
52 changes: 44 additions & 8 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
api_base: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
deployment_id: Optional[str] = None,
):
"""
Initialize the OpenAIEmbeddingFunction.
Expand All @@ -103,6 +104,7 @@ def __init__(
api_version (str, optional): The api version for the API. If not provided,
it will use the api version for the OpenAI API. This can be used to
point to a different deployment, such as an Azure deployment.
deployment_id (str, optional): Deployment ID for Azure OpenAI.
"""
try:
Expand All @@ -126,27 +128,61 @@ def __init__(
if api_version is not None:
openai.api_version = api_version

self._api_type = api_type
if api_type is not None:
openai.api_type = api_type

if organization_id is not None:
openai.organization = organization_id

self._client = openai.Embedding
self._v1 = openai.__version__.startswith('1.')
if self._v1:
if api_type == "azure":
self._client = openai.AzureOpenAI(
api_key=api_key,
api_version=api_version,
azure_endpoint=api_base
).embeddings
else:
self._client = openai.OpenAI(
api_key=api_key,
base_url=api_base
).embeddings
else:
self._client = openai.Embedding
self._model_name = model_name
self._deployment_id = deployment_id

def __call__(self, input: Documents) -> Embeddings:
# replace newlines, which can negatively affect performance.
input = [t.replace("\n", " ") for t in input]

# Call the OpenAI Embedding API
embeddings = self._client.create(input=input, engine=self._model_name)["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore

# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]
if self._v1:
embeddings = self._client.create(
input=input,
model=self._deployment_id or self._model_name
).data

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e.index) # type: ignore

# Return just the embeddings
return [result.embedding for result in sorted_embeddings]
else:
if self._api_type == "azure":
embeddings = self._client.create(
input=input,
engine=self._deployment_id or self._model_name
)["data"]
else:
embeddings = self._client.create(input=input, model=self._model_name)["data"]

# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore

# Return just the embeddings
return [result["embedding"] for result in sorted_embeddings]


class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
Expand Down

0 comments on commit a14f158

Please sign in to comment.