diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index 5e38936ef6c..141964465a5 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -85,6 +85,7 @@ def __init__(
         api_base: Optional[str] = None,
         api_type: Optional[str] = None,
         api_version: Optional[str] = None,
+        deployment_id: Optional[str] = None,
     ):
         """
         Initialize the OpenAIEmbeddingFunction.
@@ -103,6 +104,7 @@ def __init__(
             api_version (str, optional): The api version for the API. If not provided,
                 it will use the api version for the OpenAI API. This can be used to
                 point to a different deployment, such as an Azure deployment.
+            deployment_id (str, optional): Deployment ID for Azure OpenAI.
 
         """
         try:
@@ -126,27 +128,61 @@ def __init__(
         if api_version is not None:
             openai.api_version = api_version
 
+        self._api_type = api_type
         if api_type is not None:
             openai.api_type = api_type
 
         if organization_id is not None:
             openai.organization = organization_id
 
-        self._client = openai.Embedding
+        self._v1 = openai.__version__.startswith('1.')
+        if self._v1:
+            if api_type == "azure":
+                self._client = openai.AzureOpenAI(
+                    api_key=api_key,
+                    api_version=api_version,
+                    azure_endpoint=api_base
+                ).embeddings
+            else:
+                self._client = openai.OpenAI(
+                    api_key=api_key,
+                    base_url=api_base
+                ).embeddings
+        else:
+            self._client = openai.Embedding
         self._model_name = model_name
+        self._deployment_id = deployment_id
 
     def __call__(self, input: Documents) -> Embeddings:
         # replace newlines, which can negatively affect performance.
         input = [t.replace("\n", " ") for t in input]
 
         # Call the OpenAI Embedding API
-        embeddings = self._client.create(input=input, engine=self._model_name)["data"]
-
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
-
-        # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        if self._v1:
+            embeddings = self._client.create(
+                input=input,
+                model=self._deployment_id or self._model_name
+            ).data
+
+            # Sort resulting embeddings by index
+            sorted_embeddings = sorted(embeddings, key=lambda e: e.index)  # type: ignore
+
+            # Return just the embeddings
+            return [result.embedding for result in sorted_embeddings]
+        else:
+            if self._api_type == "azure":
+                embeddings = self._client.create(
+                    input=input,
+                    engine=self._deployment_id or self._model_name
+                )["data"]
+            else:
+                embeddings = self._client.create(input=input, model=self._model_name)["data"]
+
+            # Sort resulting embeddings by index
+            sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+
+            # Return just the embeddings
+            return [result["embedding"] for result in sorted_embeddings]
 
 
 class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
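
Usage sketch (not part of the patch): with the changes above, OpenAIEmbeddingFunction can target either the standard OpenAI API or an Azure OpenAI deployment, and the check on openai.__version__ routes requests through the 1.x client objects when that version is installed. The endpoint, key, API version, and deployment names below are placeholders for illustration, not values taken from the diff.

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

# Azure OpenAI: pass api_type="azure" plus the deployment that hosts the
# embedding model. All credential/endpoint values here are placeholders.
azure_ef = OpenAIEmbeddingFunction(
    api_key="<azure-api-key>",
    api_base="https://<resource>.openai.azure.com",
    api_type="azure",
    api_version="2023-05-15",                # example Azure API version
    deployment_id="<embedding-deployment>",  # placeholder deployment name
    model_name="text-embedding-ada-002",
)

# Standard OpenAI API: deployment_id is not needed.
openai_ef = OpenAIEmbeddingFunction(
    api_key="<openai-api-key>",
    model_name="text-embedding-ada-002",
)

# Either callable returns one embedding per input document.
vectors = openai_ef(["first document", "second document"])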