(fix) litellm.amoderation - support using `model=openai/omni-moderation-latest`, `model=omni-moderation-latest`, `model=None` (BerriAI#7475)

* test_moderation_endpoint

* fix litellm.amoderation
ishaan-jaff authored Dec 30, 2024
1 parent 24dd655 commit a003af6
Showing 3 changed files with 56 additions and 30 deletions.
22 changes: 20 additions & 2 deletions litellm/main.py
@@ -67,6 +67,7 @@
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.realtime_api.main import _realtime_health_check
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
@@ -4314,7 +4315,11 @@ def moderation(
 
 @client
 async def amoderation(
-    input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
+    input: str,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
 ):
     from openai import AsyncOpenAI
 
@@ -4335,6 +4340,20 @@ async def amoderation(
         )
     else:
         _openai_client = openai_client
+
+    optional_params = GenericLiteLLMParams(**kwargs)
+    try:
+        model, _custom_llm_provider, _dynamic_api_key, _dynamic_api_base = (
+            litellm.get_llm_provider(
+                model=model or "",
+                custom_llm_provider=custom_llm_provider,
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+            )
+        )
+    except litellm.BadRequestError:
+        # `model` is optional field for moderation - get_llm_provider will throw BadRequestError if model is not set / not recognized
+        pass
     if model is not None:
         response = await _openai_client.moderations.create(input=input, model=model)
     else:
@@ -5095,7 +5114,6 @@ def speech(
             aspeech=aspeech,
         )
     elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
-        from litellm.types.router import GenericLiteLLMParams
 
         generic_optional_params = GenericLiteLLMParams(**kwargs)
 
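With this change, litellm.amoderation resolves the provider from the `model` argument via litellm.get_llm_provider and, when no model is given, sends the request without an explicit model. Below is a minimal usage sketch (not part of the commit) covering the three calling styles named in the commit title; it assumes OPENAI_API_KEY is set in the environment.

import asyncio

import litellm


async def main():
    # Provider-prefixed model: get_llm_provider strips the "openai/" prefix
    # before the underlying moderations call.
    r1 = await litellm.amoderation(
        model="openai/omni-moderation-latest", input="hello this is a test"
    )

    # Bare model name: the provider is inferred as OpenAI.
    r2 = await litellm.amoderation(
        model="omni-moderation-latest", input="hello this is a test"
    )

    # No model: get_llm_provider raises BadRequestError, which amoderation
    # swallows, and OpenAI's default moderation model is used.
    r3 = await litellm.amoderation(input="hello this is a test")

    print(r1, r2, r3)


if __name__ == "__main__":
    asyncio.run(main())
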
33 changes: 5 additions & 28 deletions litellm/proxy/proxy_config.yaml
@@ -1,34 +1,11 @@
 model_list:
-  - model_name: openai/*
+  - model_name: "*"
     litellm_params:
-      model: openai/*
+      model: "openai/*"
       api_key: os.environ/OPENAI_API_KEY
-  - model_name: anthropic/*
+  - model_name: "openai/*"
     litellm_params:
-      model: anthropic/*
-      api_key: os.environ/ANTHROPIC_API_KEY
-  - model_name: bedrock/*
-    litellm_params:
-      model: bedrock/*
-
-guardrails:
-  - guardrail_name: "bedrock-pre-guard"
-    litellm_params:
-      guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
-      guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock
-      guardrailVersion: "DRAFT" # your guardrail version on bedrock
-      mode: "post_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
-    guardrail_info:
-      params:
-        - name: "toxicity_score"
-          type: "float"
-          description: "Score between 0-1 indicating content toxicity level"
-        - name: "pii_detection"
-          type: "boolean"
-
-
+      model: "openai/*"
+      api_key: os.environ/OPENAI_API_KEY
 litellm_settings:
   callbacks: ["datadog"]
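The trimmed config keeps only the "*" and "openai/*" wildcard deployments, so moderation traffic routes to OpenAI however the model is spelled. A client-side sketch (not part of the diff), assuming a proxy running locally on port 4000 with its OpenAI-compatible /moderations route enabled; the base URL and key below are placeholders.

from openai import OpenAI

# Placeholder proxy address and virtual key; adjust for your deployment.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# Matched by the "openai/*" wildcard deployment.
prefixed = client.moderations.create(
    model="openai/omni-moderation-latest", input="hello this is a test"
)

# Matched by the catch-all "*" deployment.
bare = client.moderations.create(
    model="omni-moderation-latest", input="hello this is a test"
)

print(prefixed.results[0].flagged, bare.results[0].flagged)
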
31 changes: 31 additions & 0 deletions tests/router_unit_tests/test_router_endpoints.py
@@ -228,6 +228,37 @@ async def test_rerank_endpoint(model_list):
     RerankResponse.model_validate(response)
 
 
+@pytest.mark.asyncio()
+@pytest.mark.parametrize(
+    "model", ["omni-moderation-latest", "openai/omni-moderation-latest", None]
+)
+async def test_moderation_endpoint(model):
+    litellm.set_verbose = True
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/*",
+                "litellm_params": {
+                    "model": "openai/*",
+                },
+            },
+            {
+                "model_name": "*",
+                "litellm_params": {
+                    "model": "openai/*",
+                },
+            },
+        ]
+    )
+
+    if model is None:
+        response = await router.amoderation(input="hello this is a test")
+    else:
+        response = await router.amoderation(model=model, input="hello this is a test")
+
+    print("moderation response: ", response)
+
+
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
