-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(routes): add moderation route support (#61)
* feat(routes): add moderation route support * chore(docs): reference moderation support in readme + coverage
- Loading branch information
1 parent
e5754f8
commit 70d76f0
Showing
11 changed files
with
268 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import openai | ||
import pytest | ||
|
||
import openai_responses | ||
from openai_responses import OpenAIMock | ||
|
||
|
||
@pytest.fixture() | ||
def client(): | ||
return openai.Client(api_key="sk-fake123") | ||
|
||
|
||
@openai_responses.mock() | ||
def test_create_moderation_returns_default( | ||
openai_mock: OpenAIMock, client: openai.Client | ||
): | ||
expected_prefix = "modr-" | ||
expect_model = "text-moderation-007" | ||
|
||
actual = client.moderations.create(input="Test input") | ||
|
||
assert actual.id.startswith( | ||
expected_prefix | ||
), f"Expected id to start with {expected_prefix}" | ||
assert actual.model == expect_model | ||
assert len(actual.results) == 0 | ||
assert openai_mock.moderations.create.route.call_count == 1 | ||
|
||
|
||
@openai_responses.mock() | ||
def test_create_moderation_applies_defaults_if_partial_response_provided( | ||
openai_mock: OpenAIMock, client: openai.Client | ||
): | ||
openai_mock.moderations.create.response = { | ||
"results": [ | ||
{ | ||
"flagged": True, | ||
"categories": {"harassment": True, "violence/graphic": True}, | ||
"category_scores": {"harassment": 0.9, "violence/graphic": 0.8}, | ||
} | ||
] | ||
} | ||
|
||
actual = client.moderations.create(input="Test input") | ||
|
||
assert len(actual.results) == 1 | ||
assert actual.results[0].model_dump(by_alias=True) == { | ||
"flagged": True, | ||
"categories": { | ||
"harassment": True, | ||
"harassment/threatening": False, | ||
"hate": False, | ||
"hate/threatening": False, | ||
"self-harm": False, | ||
"self-harm/instructions": False, | ||
"self-harm/intent": False, | ||
"sexual": False, | ||
"sexual/minors": False, | ||
"violence": False, | ||
"violence/graphic": True, | ||
}, | ||
"category_scores": { | ||
"harassment": 0.9, | ||
"harassment/threatening": 0.0, | ||
"hate": 0.0, | ||
"hate/threatening": 0.0, | ||
"self-harm": 0.0, | ||
"self-harm/instructions": 0.0, | ||
"self-harm/intent": 0.0, | ||
"sexual": 0.0, | ||
"sexual/minors": 0.0, | ||
"violence": 0.0, | ||
"violence/graphic": 0.8, | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from openai.types.moderation import Moderation, Categories, CategoryScores | ||
from openai.types.moderation_create_response import ModerationCreateResponse | ||
|
||
import httpx | ||
import respx | ||
|
||
from ._base import StatelessRoute | ||
|
||
from .._types.partials.moderation import ( | ||
PartialModerationCreateResponse, | ||
PartialCategories, | ||
PartialCategoryScores, | ||
) | ||
|
||
from .._utils.serde import model_parse | ||
from .._utils.faker import faker | ||
|
||
__all__ = ["ModerationCreateRoute"] | ||
|
||
_default_categories: PartialCategories = { | ||
"harassment": False, | ||
"harassment/threatening": False, | ||
"hate": False, | ||
"hate/threatening": False, | ||
"self-harm": False, | ||
"self-harm/instructions": False, | ||
"self-harm/intent": False, | ||
"sexual": False, | ||
"sexual/minors": False, | ||
"violence": False, | ||
"violence/graphic": False, | ||
} | ||
|
||
_default_category_scores: PartialCategoryScores = { | ||
"harassment": 0.0, | ||
"harassment/threatening": 0.0, | ||
"hate": 0.0, | ||
"hate/threatening": 0.0, | ||
"self-harm": 0.0, | ||
"self-harm/instructions": 0.0, | ||
"self-harm/intent": 0.0, | ||
"sexual": 0.0, | ||
"sexual/minors": 0.0, | ||
"violence": 0.0, | ||
"violence/graphic": 0.0, | ||
} | ||
|
||
|
||
class ModerationCreateRoute( | ||
StatelessRoute[ModerationCreateResponse, PartialModerationCreateResponse] | ||
): | ||
def __init__(self, router: respx.MockRouter) -> None: | ||
super().__init__(route=router.post(url__regex="/moderations"), status_code=200) | ||
|
||
@staticmethod | ||
def _build( | ||
partial: PartialModerationCreateResponse, request: httpx.Request | ||
) -> ModerationCreateResponse: | ||
partial_results = partial.get("results", []) | ||
moderation_results = [ | ||
Moderation( | ||
categories=model_parse( | ||
Categories, | ||
_default_categories | ||
| partial_result.get("categories", _default_categories), | ||
), | ||
category_scores=model_parse( | ||
CategoryScores, | ||
_default_category_scores | ||
| partial_result.get("category_scores", _default_category_scores), | ||
), | ||
flagged=partial_result.get("flagged", False), | ||
) | ||
for partial_result in partial_results | ||
] | ||
|
||
return ModerationCreateResponse( | ||
id=partial.get("id", faker.moderation.id()), | ||
model=partial.get("model", "text-moderation-007"), | ||
results=moderation_results, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import List, TypedDict | ||
from typing_extensions import NotRequired | ||
|
||
__all__ = [ | ||
"PartialModerationCreateResponse", | ||
"PartialCategories", | ||
"PartialCategoryScores", | ||
] | ||
|
||
PartialCategories = TypedDict( | ||
"PartialCategories", | ||
{ | ||
"harassment": NotRequired[bool], | ||
"harassment/threatening": NotRequired[bool], | ||
"hate": NotRequired[bool], | ||
"hate/threatening": NotRequired[bool], | ||
"self-harm": NotRequired[bool], | ||
"self-harm/instructions": NotRequired[bool], | ||
"self-harm/intent": NotRequired[bool], | ||
"sexual": NotRequired[bool], | ||
"sexual/minors": NotRequired[bool], | ||
"violence": NotRequired[bool], | ||
"violence/graphic": NotRequired[bool], | ||
}, | ||
) | ||
|
||
PartialCategoryScores = TypedDict( | ||
"PartialCategoryScores", | ||
{ | ||
"harassment": NotRequired[float], | ||
"harassment/threatening": NotRequired[float], | ||
"hate": NotRequired[float], | ||
"hate/threatening": NotRequired[float], | ||
"self-harm": NotRequired[float], | ||
"self-harm/instructions": NotRequired[float], | ||
"self-harm/intent": NotRequired[float], | ||
"sexual": NotRequired[float], | ||
"sexual/minors": NotRequired[float], | ||
"violence": NotRequired[float], | ||
"violence/graphic": NotRequired[float], | ||
}, | ||
) | ||
|
||
|
||
class PartialModeration(TypedDict): | ||
categories: NotRequired[PartialCategories] | ||
category_scores: NotRequired[PartialCategoryScores] | ||
flagged: NotRequired[bool] | ||
|
||
|
||
class PartialModerationCreateResponse(TypedDict): | ||
id: NotRequired[str] | ||
model: NotRequired[str] | ||
results: NotRequired[List[PartialModeration]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Optional | ||
|
||
import httpx | ||
|
||
from openai.types.moderation_create_response import ModerationCreateResponse | ||
|
||
from ._base import _generic_builder | ||
from ..._routes.moderation import ModerationCreateRoute | ||
from ..._types.partials.moderation import PartialModerationCreateResponse | ||
|
||
__all__ = ["moderation_create_response_from_create_request"] | ||
|
||
|
||
def moderation_create_response_from_create_request( | ||
request: httpx.Request, | ||
*, | ||
extra: Optional[PartialModerationCreateResponse] = None, | ||
) -> ModerationCreateResponse: | ||
return _generic_builder(ModerationCreateRoute, request, extra) |