Skip to content

Commit

Permalink
Refactor openai moderation
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Nov 23, 2023
1 parent 2b18bd1 commit 0ea1a8d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/moderation"
"github.com/hupe1980/golc/schema"
)

func main() {
moderationChain, err := chain.NewOpenAIModeration(os.Getenv("OPENAI_API_KEY"))
moderationChain, err := moderation.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
log.Fatal(err)
}
Expand Down
44 changes: 22 additions & 22 deletions chain/openai_moderation.go → moderation/openai.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chain
package moderation

import (
"context"
Expand All @@ -10,8 +10,8 @@ import (
"github.com/sashabaranov/go-openai"
)

// Compile time check to ensure OpenAIModeration satisfies the Chain interface.
var _ schema.Chain = (*OpenAIModeration)(nil)
// Compile time check to ensure OpenAI satisfies the Chain interface.
var _ schema.Chain = (*OpenAI)(nil)

// OpenAIClient is an interface representing an OpenAI client that can make moderation requests.
type OpenAIClient interface {
Expand All @@ -23,8 +23,8 @@ type OpenAIClient interface {
// OpenAIModerateFunc is a function type for handling the moderation response from OpenAI.
type OpenAIModerateFunc func(id, model string, result openai.Result) (schema.ChainValues, error)

// OpenAIModerationOptions contains options for configuring the OpenAIModeration chain.
type OpenAIModerationOptions struct {
// OpenAIOptions contains options for configuring the OpenAI chain.
type OpenAIOptions struct {
// CallbackOptions embeds CallbackOptions to include the verbosity setting and callbacks.
*schema.CallbackOptions
// ModelName is the name of the OpenAI model to use for moderation.
Expand All @@ -37,21 +37,21 @@ type OpenAIModerationOptions struct {
OpenAIModerateFunc OpenAIModerateFunc
}

// OpenAIModeration represents a chain that performs moderation using the OpenAI API.
type OpenAIModeration struct {
// OpenAI represents a chain that performs moderation using the OpenAI API.
type OpenAI struct {
client OpenAIClient
opts OpenAIModerationOptions
opts OpenAIOptions
}

// NewOpenAIModeration creates a new instance of the OpenAIModeration chain using the provided API key and options.
func NewOpenAIModeration(apiKey string, optFns ...func(o *OpenAIModerationOptions)) (*OpenAIModeration, error) {
// NewOpenAI creates a new instance of the OpenAI chain using the provided API key and options.
func NewOpenAI(apiKey string, optFns ...func(o *OpenAIOptions)) (*OpenAI, error) {
client := openai.NewClient(apiKey)
return NewOpenAIModerationFromClient(client, optFns...)
return NewOpenAIFromClient(client, optFns...)
}

// NewOpenAIModerationFromClient creates a new instance of the OpenAIModeration chain with the given OpenAI client and options.
func NewOpenAIModerationFromClient(client OpenAIClient, optFns ...func(o *OpenAIModerationOptions)) (*OpenAIModeration, error) {
opts := OpenAIModerationOptions{
// NewOpenAIFromClient creates a new instance of the OpenAI chain with the given OpenAI client and options.
func NewOpenAIFromClient(client OpenAIClient, optFns ...func(o *OpenAIOptions)) (*OpenAI, error) {
opts := OpenAIOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
Expand All @@ -76,15 +76,15 @@ func NewOpenAIModerationFromClient(client OpenAIClient, optFns ...func(o *OpenAI
}
}

return &OpenAIModeration{
return &OpenAI{
client: client,
opts: opts,
}, nil
}

// Call executes the openai moderation chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *OpenAIModeration) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
func (c *OpenAI) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
opts := schema.CallOptions{
CallbackManger: &callback.NoopManager{},
}
Expand Down Expand Up @@ -116,31 +116,31 @@ func (c *OpenAIModeration) Call(ctx context.Context, inputs schema.ChainValues,
}

// Memory returns the memory associated with the chain.
func (c *OpenAIModeration) Memory() schema.Memory {
func (c *OpenAI) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *OpenAIModeration) Type() string {
func (c *OpenAI) Type() string {
return "OpenAIModeration"
}

// Verbose returns the verbosity setting of the chain.
func (c *OpenAIModeration) Verbose() bool {
func (c *OpenAI) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *OpenAIModeration) Callbacks() []schema.Callback {
func (c *OpenAI) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *OpenAIModeration) InputKeys() []string {
func (c *OpenAI) InputKeys() []string {
return []string{c.opts.InputKey}
}

// OutputKeys returns the output keys the chain will return.
func (c *OpenAIModeration) OutputKeys() []string {
func (c *OpenAI) OutputKeys() []string {
return []string{c.opts.OutputKey}
}
10 changes: 5 additions & 5 deletions chain/openai_moderation_test.go → moderation/openai_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chain
package moderation

import (
"context"
Expand All @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestOpenAIModeration(t *testing.T) {
func TestOpenAI(t *testing.T) {
// Test cases
testCases := []struct {
name string
Expand Down Expand Up @@ -43,7 +43,7 @@ func TestOpenAIModeration(t *testing.T) {
Results: []openai.Result{{Flagged: tc.flagged}},
},
}
chain, err := NewOpenAIModerationFromClient(fakeClient)
chain, err := NewOpenAIFromClient(fakeClient)
require.NoError(t, err)

// Test
Expand Down Expand Up @@ -76,7 +76,7 @@ func TestOpenAIModeration(t *testing.T) {
},
}

chain, err := NewOpenAIModerationFromClient(fakeClient, func(o *OpenAIModerationOptions) {
chain, err := NewOpenAIFromClient(fakeClient, func(o *OpenAIOptions) {
o.OpenAIModerateFunc = func(id, model string, result openai.Result) (schema.ChainValues, error) {
if result.Flagged {
return nil, errors.New("custom content policy violation")
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestOpenAIModeration(t *testing.T) {
},
}

chain, err := NewOpenAIModerationFromClient(fakeClient, func(o *OpenAIModerationOptions) {
chain, err := NewOpenAIFromClient(fakeClient, func(o *OpenAIOptions) {
o.OpenAIModerateFunc = func(id, model string, result openai.Result) (schema.ChainValues, error) {
if result.Flagged {
return nil, errors.New("custom content policy violation")
Expand Down

0 comments on commit 0ea1a8d

Please sign in to comment.