community: add AzureAIContentSafetyChain #27480

Closed
194 changes: 194 additions & 0 deletions docs/docs/integrations/chains/azure_ai_content_safety.ipynb
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `AzureAIContentSafetyChain`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> [Azure AI Content Safety Chain](https://learn.microsoft.com/python/api/overview/azure/ai-contentsafety-readme?view=azure-python) is a wrapper around\n",
"> the Azure AI Content Safety service, implemented in LangChain using the LangChain \n",
"> [Runnables](https://python.langchain.com/docs/how_to/lcel_cheatsheet/) base class to allow use in a Runnables Sequence.\n",
"\n",
"The Class can be used to stop or filter content based on the Azure AI Content Safety policy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the required imports, here we will use a `ChatPromptTemplate` for convenience and the `AzureChatOpenAI`, however, any LangChain integrated model will work in a chain."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain_community.chains.azure_content_safety_chain import (\n",
" AzureAIContentSafetyChain,\n",
" AzureHarmfulContentError,\n",
")\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_openai import AzureChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"moderate = AzureAIContentSafetyChain()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = AzureChatOpenAI(\n",
" openai_api_version=os.environ[\"OPENAI_API_VERSION\"],\n",
" azure_deployment=os.environ[\"COMPLETIONS_MODEL\"],\n",
" azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
" api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Combine the objects to create a LangChain RunnablesSequence"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"moderated_chain = moderate | prompt | model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"response = moderated_chain.invoke({\"input\": \"I like you!\"})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'I like you!'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response.content"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With harmful content"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Harmful content: I hate you!\n"
]
},
{
"ename": "AzureHarmfulContentError",
"evalue": "The input has breached Azure's Content Safety Policy",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAzureHarmfulContentError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmoderated_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mI hate you!\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m AzureHarmfulContentError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mHarmful content: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39minput\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain_core/runnables/base.py:3020\u001b[0m, in \u001b[0;36mRunnableSequence.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 3018\u001b[0m context\u001b[38;5;241m.\u001b[39mrun(_set_config_context, config)\n\u001b[1;32m 3019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 3020\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mcontext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3021\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3022\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m context\u001b[38;5;241m.\u001b[39mrun(step\u001b[38;5;241m.\u001b[39minvoke, \u001b[38;5;28minput\u001b[39m, config)\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:170\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 169\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 171\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs)\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m include_run_info:\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:160\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inputs)\n\u001b[1;32m 159\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 160\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_arg_supported\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(inputs)\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 165\u001b[0m final_outputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(\n\u001b[1;32m 166\u001b[0m inputs, outputs, return_only_outputs\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:161\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._call\u001b[0;34m(self, inputs, run_manager)\u001b[0m\n\u001b[1;32m 158\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclient\u001b[38;5;241m.\u001b[39manalyze_text(request)\n\u001b[1;32m 160\u001b[0m result \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mcategories_analysis\n\u001b[0;32m--> 161\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_detect_harmful_content\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_key: output, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: output}\n",
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:142\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._detect_harmful_content\u001b[0;34m(self, text, results)\u001b[0m\n\u001b[1;32m 137\u001b[0m error_str \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 138\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe input text contains harmful content \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maccording to Azure OpenAI\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms content policy\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m AzureHarmfulContentError(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mtext)\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_str\n",
"\u001b[0;31mAzureHarmfulContentError\u001b[0m: The input has breached Azure's Content Safety Policy"
]
}
],
"source": [
"try:\n",
" response = moderated_chain.invoke({\"input\": \"I hate you!\"})\n",
"except AzureHarmfulContentError as e:\n",
" print(f\"Harmful content: {e.input}\")\n",
" raise"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
@@ -4,6 +4,7 @@ anthropic>=0.3.11,<0.4
arxiv>=1.4,<2
assemblyai>=0.17.0,<0.18
atlassian-python-api>=3.36.0,<4
azure-ai-contentsafety>=1.0.0
azure-ai-documentintelligence>=1.0.0b1,<2
azure-identity>=1.15.0,<2
azure-search-documents==11.4.0
182 changes: 182 additions & 0 deletions libs/community/langchain_community/chains/azure_content_safety_chain.py
@@ -0,0 +1,182 @@
"""Pass input through an azure content safety resource."""

from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.exceptions import LangChainException
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator


class AzureHarmfulContentError(LangChainException):
"""Exception for handling harmful content detected
in input for a model or chain according to Azure's
content safety policy."""

def __init__(
self,
input: str,
):
"""Constructor

Args:
input (str): The input given by the user to the model.
"""
self.input = input
self.message = "The input has breached Azure's Content Safety Policy"
super().__init__(self.message)


class AzureAIContentSafetyChain(Chain):
"""
A wrapper for the Azure AI Content Safety API in a Runnable form.
Allows for harmful content detection and filtering before input is
provided to a model.

**Note**:
This Service will filter input that shows any sign of harmful content,
this is non-configurable.

Attributes:
error (bool): Whether to raise an error if harmful content is detected.
content_safety_key (Optional[str]): API key for Azure Content Safety.
content_safety_endpoint (Optional[str]): Endpoint URL for Azure Content Safety.

Setup:
1. Follow the instructions here to deploy Azure AI Content Safety:
https://learn.microsoft.com/azure/ai-services/content-safety/overview

2. Install ``langchain`` ``langchain_community`` and set the following
environment variables:

.. code-block:: bash

pip install -U langchain langchain-community

export AZURE_CONTENT_SAFETY_KEY="your-api-key"
export AZURE_CONTENT_SAFETY_ENDPOINT="https://your-endpoint.azure.com/"
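
        The key and endpoint can also be passed directly when constructing
        the chain instead of via environment variables (a minimal sketch
        using the ``content_safety_key`` and ``content_safety_endpoint``
        attributes above):

        .. code-block:: python

            moderate = AzureAIContentSafetyChain(
                content_safety_key="your-api-key",
                content_safety_endpoint="https://your-endpoint.azure.com/",
            )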


    Example Usage (with safe content):
        .. code-block:: python

            from langchain_community.chains.azure_content_safety_chain import (
                AzureAIContentSafetyChain,
            )
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            moderated_chain.invoke({"input": "Hey, how are you?"})

    Example Usage (with harmful content):
        .. code-block:: python

            from langchain_community.chains.azure_content_safety_chain import (
                AzureAIContentSafetyChain,
                AzureHarmfulContentError,
            )
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            try:
                response = moderated_chain.invoke({"input": "I hate you!"})
            except AzureHarmfulContentError as e:
                print(f"Harmful content: {e.input}")
                raise
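
    Example Usage (without raising, ``error=False``):
        A minimal sketch, assuming the same ``prompt`` and ``model`` objects
        as above. With ``error=False`` the chain does not raise on harmful
        content; instead it replaces the flagged input with a short policy
        message, which is what the next step of the sequence receives.

        .. code-block:: python

            moderate = AzureAIContentSafetyChain(error=False)
            moderated_chain = moderate | prompt | model

            # The prompt (and therefore the model) now sees the policy
            # message rather than the original harmful input.
            response = moderated_chain.invoke({"input": "I hate you!"})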
"""

client: Any = None #: :meta private:
error: bool = True
"""Whether or not to error if bad content was found."""
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
content_safety_key: Optional[str] = None
content_safety_endpoint: Optional[str] = None

@property
def input_keys(self) -> List[str]:
"""Expect input key.

:meta private:
"""
return [self.input_key]

@property
def output_keys(self) -> List[str]:
"""Return output key.

:meta private:
"""
return [self.output_key]

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that api key and python package exists in environment."""
content_safety_key = get_from_dict_or_env(
values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
)
content_safety_endpoint = get_from_dict_or_env(
values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
)
try:
import azure.ai.contentsafety as sdk
from azure.core.credentials import AzureKeyCredential

values["client"] = sdk.ContentSafetyClient(
endpoint=content_safety_endpoint,
credential=AzureKeyCredential(content_safety_key),
)

except ImportError:
raise ImportError(
"azure-ai-contentsafety is not installed. "
"Run `pip install azure-ai-contentsafety` to install."
)
return values

def _detect_harmful_content(self, text: str, results: Any) -> str:
contains_harmful_content = False

for category in results:
if category["severity"] > 0:
contains_harmful_content = True

if contains_harmful_content:
error_str = (
"The input text contains harmful content "
"according to Azure OpenAI's content policy"
)
if self.error:
raise AzureHarmfulContentError(input=text)
else:
return error_str

return text

def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]

from azure.ai.contentsafety.models import AnalyzeTextOptions

request = AnalyzeTextOptions(text=text)
response = self.client.analyze_text(request)

result = response.categories_analysis
output = self._detect_harmful_content(text, result)

return {self.input_key: output, self.output_key: output}