Skip to content

Commit

Permalink
Merge pull request #35 from yorinasub17/features/moderations
Browse files Browse the repository at this point in the history
Add support for Moderations API endpoint
  • Loading branch information
tylermann authored Mar 29, 2023
2 parents f4f8f0f + 66b828f commit e348aa5
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
29 changes: 29 additions & 0 deletions gpt3.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ const (
TextEmbeddingAda002 = "text-embedding-ada-002"
)

const (
TextModerationLatest = "text-moderation-latest"
TextModerationStable = "text-moderation-stable"
)

const (
defaultBaseURL = "https://api.openai.com/v1"
defaultUserAgent = "go-gpt3"
Expand Down Expand Up @@ -104,6 +109,10 @@ type Client interface {

// Returns an embedding using the provided request.
Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error)

// Moderation performs a moderation check on the given text against an OpenAI classifier to determine whether the
// provided content complies with OpenAI's usage policies.
Moderation(ctx context.Context, request ModerationRequest) (*ModerationResponse, error)
}

type client struct {
Expand Down Expand Up @@ -376,6 +385,26 @@ func (c *client) Embeddings(ctx context.Context, request EmbeddingsRequest) (*Em
return &output, nil
}

// Moderation performs a moderation check on the given text against an OpenAI classifier.
//
// See: https://platform.openai.com/docs/api-reference/moderations/create
func (c *client) Moderation(ctx context.Context, request ModerationRequest) (*ModerationResponse, error) {
req, err := c.newRequest(ctx, "POST", "/moderations", request)
if err != nil {
return nil, err
}
resp, err := c.performRequest(req)
if err != nil {
return nil, err
}

output := ModerationResponse{}
if err := getResponseObject(resp, &output); err != nil {
return nil, err
}
return &output, nil
}

func (c *client) performRequest(req *http.Request) (*http.Response, error) {
resp, err := c.httpClient.Do(req)
if err != nil {
Expand Down
38 changes: 38 additions & 0 deletions gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ func TestRequestCreationFails(t *testing.T) {
},
"Post \"https://api.openai.com/v1/embeddings\": request error",
},
{
"Moderation",
func() (interface{}, error) {
return client.Moderation(ctx, gpt3.ModerationRequest{})
},
"Post \"https://api.openai.com/v1/moderations\": request error",
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -312,6 +319,37 @@ func TestResponses(t *testing.T) {
},
},
},
{
"Moderation",
func() (interface{}, error) {
return client.Moderation(ctx, gpt3.ModerationRequest{})
},
&gpt3.ModerationResponse{
ID: "123",
Model: "text-moderation-001",
Results: []gpt3.ModerationResult{{
Flagged: false,
Categories: gpt3.ModerationCategoryResult{
Hate: false,
HateThreatening: false,
SelfHarm: false,
Sexual: false,
SexualMinors: false,
Violence: false,
ViolenceGraphic: false,
},
CategoryScores: gpt3.ModerationCategoryScores{
Hate: 0.22714105248451233,
HateThreatening: 0.22714105248451233,
SelfHarm: 0.005232391878962517,
Sexual: 0.01407341007143259,
SexualMinors: 0.0038522258400917053,
Violence: 0.009223177433013916,
ViolenceGraphic: 0.036865197122097015,
},
}},
},
},
}

for _, tc := range testCases {
Expand Down
45 changes: 45 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,48 @@ type SearchResponse struct {
Data []SearchData `json:"data"`
Object string `json:"object"`
}

// ModerationRequest is a request for the moderation API.
type ModerationRequest struct {
// Input is the input text that should be classified. Required.
Input string `json:"input"`
// Model is the content moderation model to use. If not specified, will default to OpenAI API defaults, which is
// currently "text-moderation-latest".
Model string `json:"model,omitempty"`
}

// ModerationCategoryResult shows the categories that the moderation classifier flagged the input text for.
type ModerationCategoryResult struct {
Hate bool `json:"hate"`
HateThreatening bool `json:"hate/threatening"`
SelfHarm bool `json:"self-harm"`
Sexual bool `json:"sexual"`
SexualMinors bool `json:"sexual/minors"`
Violence bool `json:"violence"`
ViolenceGraphic bool `json:"violence/graphic"`
}

// ModerationCategoryScores shows the classifier scores for each moderation category.
type ModerationCategoryScores struct {
Hate float32 `json:"hate"`
HateThreatening float32 `json:"hate/threatening"`
SelfHarm float32 `json:"self-harm"`
Sexual float32 `json:"sexual"`
SexualMinors float32 `json:"sexual/minors"`
Violence float32 `json:"violence"`
ViolenceGraphic float32 `json:"violence/graphic"`
}

// ModerationResult represents a single moderation classification result returned by the moderation API.
type ModerationResult struct {
Flagged bool `json:"flagged"`
Categories ModerationCategoryResult `json:"categories"`
CategoryScores ModerationCategoryScores `json:"category_scores"`
}

// ModerationResponse is the full response from a request to the moderation API.
type ModerationResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Results []ModerationResult `json:"results"`
}

0 comments on commit e348aa5

Please sign in to comment.