diff --git a/docs/user_guides/configuration-guide.md b/docs/user_guides/configuration-guide.md index 1c6f3a864..ade1bd239 100644 --- a/docs/user_guides/configuration-guide.md +++ b/docs/user_guides/configuration-guide.md @@ -385,6 +385,7 @@ The following tables lists the supported embedding providers: | OpenAI | `openai` | `text-embedding-ada-002`, etc. | | SentenceTransformers | `SentenceTransformers` | `all-MiniLM-L6-v2`, etc. | | NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. | +| AzureOpenAI | `AzureOpenAI` | `text-embedding-ada-002`, etc. ```{note} You can use any of the supported models for any of the supported embedding providers. diff --git a/nemoguardrails/embeddings/providers/__init__.py b/nemoguardrails/embeddings/providers/__init__.py index 1bdfab5dc..069cad8cb 100644 --- a/nemoguardrails/embeddings/providers/__init__.py +++ b/nemoguardrails/embeddings/providers/__init__.py @@ -18,7 +18,7 @@ from typing import Optional, Type -from . import fastembed, nim, openai, sentence_transformers +from . import azureopenai, fastembed, nim, openai, sentence_transformers from .base import EmbeddingModel from .registry import EmbeddingProviderRegistry @@ -65,6 +65,7 @@ def register_embedding_provider( register_embedding_provider(fastembed.FastEmbedEmbeddingModel) register_embedding_provider(openai.OpenAIEmbeddingModel) +register_embedding_provider(azureopenai.AzureEmbeddingModel) register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel) register_embedding_provider(nim.NIMEmbeddingModel) register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel) diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py new file mode 100644 index 000000000..538e07245 --- /dev/null +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from typing import List + +from .base import EmbeddingModel + + +def get_executor(): + from . import embeddings_executor + + return embeddings_executor + +class AzureEmbeddingModel(EmbeddingModel): + """Embedding model using Azure OpenAI. + + This class represents an embedding model that utilizes the Azure OpenAI API + for generating text embeddings. + + Args: + embedding_model (str): The name of the Azure OpenAI deployment model (e.g., "text-embedding-ada-002"). + """ + + engine_name = "AzureOpenAI" + + # Lookup table for model embedding dimensions + MODEL_DIMENSIONS = { + "text-embedding-ada-002": 1536, + # Add more models and their dimensions here if needed + } + + def __init__(self, embedding_model: str): + try: + from openai import AzureOpenAI + except ImportError: + raise ImportError( + "Could not import openai, please install it with " + "`pip install openai`." + ) + # Set Azure OpenAI API credentials + self.client = AzureOpenAI( + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + ) + + self.embedding_model = embedding_model + self.embedding_size = self._get_embedding_dimension() + + def _get_embedding_dimension(self): + """Retrieve the embedding dimension for the specified model.""" + if self.embedding_model in self.MODEL_DIMENSIONS: + return self.MODEL_DIMENSIONS[self.embedding_model] + else: + # Perform a first encoding to get the embedding size + self.embedding_size = len(self.encode(["test"])[0]) + + async def encode_async(self, documents: List[str]) -> List[List[float]]: + """Asynchronously encode a list of documents into their corresponding embeddings. + + Args: + documents (List[str]): The list of documents to be encoded. + + Returns: + List[List[float]]: The list of embeddings, where each embedding is a list of floats. + """ + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(get_executor(), self.encode, documents) + return result + + def encode(self, documents: List[str]) -> List[List[float]]: + """Encode a list of documents into their corresponding embeddings. + + Args: + documents (List[str]): The list of documents to be encoded. + + Returns: + List[List[float]]: The list of embeddings, where each embedding is a list of floats. + + Raises: + RuntimeError: If the API call fails. + """ + try: + response = self.client.embeddings.create( + model=self.embedding_model, input=documents + ) + embeddings = [record.embedding for record in response.data] + return embeddings + except Exception as e: + raise RuntimeError(f"Failed to retrieve embeddings: {e}") diff --git a/tests/test_configs/with_azureopenai_embeddings/config.co b/tests/test_configs/with_azureopenai_embeddings/config.co new file mode 100644 index 000000000..56035e40c --- /dev/null +++ b/tests/test_configs/with_azureopenai_embeddings/config.co @@ -0,0 +1,12 @@ +define user ask capabilities + "What can you do?" + "What can you help me with?" + "tell me what you can do" + "tell me about you" + +define bot inform capabilities + "I am an AI assistant that helps answer questions." + +define flow + user ask capabilities + bot inform capabilities diff --git a/tests/test_configs/with_azureopenai_embeddings/config.yml b/tests/test_configs/with_azureopenai_embeddings/config.yml new file mode 100644 index 000000000..718236894 --- /dev/null +++ b/tests/test_configs/with_azureopenai_embeddings/config.yml @@ -0,0 +1,8 @@ +models: + - type: main + engine: azure + model: gpt-4o + + - type: embeddings + engine: AzureOpenAI + model: text-embedding-ada-002 diff --git a/tests/test_embeddings_azureopenai.py b/tests/test_embeddings_azureopenai.py new file mode 100644 index 000000000..74f1c6d44 --- /dev/null +++ b/tests/test_embeddings_azureopenai.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from nemoguardrails import LLMRails, RailsConfig + +try: + from nemoguardrails.embeddings.providers.azureopenai import AzureEmbeddingModel +except ImportError: + # Ignore this if running in test environment when azureopenai not installed. + AzureEmbeddingModel = None +CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs") + +LIVE_TEST_MODE = os.environ.get("LIVE_TEST") + + +@pytest.fixture +def app(): + """Load the configuration where we replace FastEmbed with AzureOpenAI.""" + config = RailsConfig.from_path( + os.path.join(CONFIGS_FOLDER, "with_azureopenai_embeddings") + ) + + return LLMRails(config) + + +@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") +def test_custom_llm_registration(app): + assert isinstance( + app.llm_generation_actions.flows_index._model, AzureEmbeddingModel + ) + + +@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") +@pytest.mark.asyncio +async def test_live_query(): + config = RailsConfig.from_path( + os.path.join(CONFIGS_FOLDER, "with_azureopenai_embeddings") + ) + app = LLMRails(config) + + result = await app.generate_async( + messages=[{"role": "user", "content": "tell me what you can do"}] + ) + + assert result == { + "role": "assistant", + "content": "I am an AI assistant that helps answer questions.", + } + + +@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") +@pytest.mark.asyncio +def test_live_query(app): + result = app.generate( + messages=[{"role": "user", "content": "tell me what you can do"}] + ) + + assert result == { + "role": "assistant", + "content": "I am an AI assistant that helps answer questions.", + } + + +@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") +def test_sync_embeddings(): + model = AzureEmbeddingModel("text-embedding-ada-002") + + result = model.encode(["test"]) + + assert len(result[0]) == 1536 + + +@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.") +@pytest.mark.asyncio +async def test_async_embeddings(): + model = AzureEmbeddingModel("text-embedding-ada-002") + + result = await model.encode_async(["test"]) + + assert len(result[0]) == 1536