diff --git a/examples/openai_moderation_chain/main.go b/examples/openai_moderation/main.go similarity index 76% rename from examples/openai_moderation_chain/main.go rename to examples/openai_moderation/main.go index 208f10f..90eede8 100644 --- a/examples/openai_moderation_chain/main.go +++ b/examples/openai_moderation/main.go @@ -7,12 +7,12 @@ import ( "os" "github.com/hupe1980/golc" - "github.com/hupe1980/golc/chain" + "github.com/hupe1980/golc/moderation" "github.com/hupe1980/golc/schema" ) func main() { - moderationChain, err := chain.NewOpenAIModeration(os.Getenv("OPENAI_API_KEY")) + moderationChain, err := moderation.NewOpenAI(os.Getenv("OPENAI_API_KEY")) if err != nil { log.Fatal(err) } diff --git a/chain/openai_moderation.go b/moderation/openai.go similarity index 68% rename from chain/openai_moderation.go rename to moderation/openai.go index f0bbe32..4698e37 100644 --- a/chain/openai_moderation.go +++ b/moderation/openai.go @@ -1,4 +1,4 @@ -package chain +package moderation import ( "context" @@ -10,8 +10,8 @@ import ( "github.com/sashabaranov/go-openai" ) -// Compile time check to ensure OpenAIModeration satisfies the Chain interface. -var _ schema.Chain = (*OpenAIModeration)(nil) +// Compile time check to ensure OpenAI satisfies the Chain interface. +var _ schema.Chain = (*OpenAI)(nil) // OpenAIClient is an interface representing an OpenAI client that can make moderation requests. type OpenAIClient interface { @@ -23,8 +23,8 @@ type OpenAIClient interface { // OpenAIModerateFunc is a function type for handling the moderation response from OpenAI. type OpenAIModerateFunc func(id, model string, result openai.Result) (schema.ChainValues, error) -// OpenAIModerationOptions contains options for configuring the OpenAIModeration chain. -type OpenAIModerationOptions struct { +// OpenAIOptions contains options for configuring the OpenAI chain. +type OpenAIOptions struct { // CallbackOptions embeds CallbackOptions to include the verbosity setting and callbacks. *schema.CallbackOptions // ModelName is the name of the OpenAI model to use for moderation. @@ -37,21 +37,21 @@ type OpenAIModerationOptions struct { OpenAIModerateFunc OpenAIModerateFunc } -// OpenAIModeration represents a chain that performs moderation using the OpenAI API. -type OpenAIModeration struct { +// OpenAI represents a chain that performs moderation using the OpenAI API. +type OpenAI struct { client OpenAIClient - opts OpenAIModerationOptions + opts OpenAIOptions } -// NewOpenAIModeration creates a new instance of the OpenAIModeration chain using the provided API key and options. -func NewOpenAIModeration(apiKey string, optFns ...func(o *OpenAIModerationOptions)) (*OpenAIModeration, error) { +// NewOpenAI creates a new instance of the OpenAI chain using the provided API key and options. +func NewOpenAI(apiKey string, optFns ...func(o *OpenAIOptions)) (*OpenAI, error) { client := openai.NewClient(apiKey) - return NewOpenAIModerationFromClient(client, optFns...) + return NewOpenAIFromClient(client, optFns...) } -// NewOpenAIModerationFromClient creates a new instance of the OpenAIModeration chain with the given OpenAI client and options. -func NewOpenAIModerationFromClient(client OpenAIClient, optFns ...func(o *OpenAIModerationOptions)) (*OpenAIModeration, error) { - opts := OpenAIModerationOptions{ +// NewOpenAIFromClient creates a new instance of the OpenAI chain with the given OpenAI client and options. +func NewOpenAIFromClient(client OpenAIClient, optFns ...func(o *OpenAIOptions)) (*OpenAI, error) { + opts := OpenAIOptions{ CallbackOptions: &schema.CallbackOptions{ Verbose: golc.Verbose, }, @@ -76,7 +76,7 @@ func NewOpenAIModerationFromClient(client OpenAIClient, optFns ...func(o *OpenAI } } - return &OpenAIModeration{ + return &OpenAI{ client: client, opts: opts, }, nil @@ -84,7 +84,7 @@ func NewOpenAIModerationFromClient(client OpenAIClient, optFns ...func(o *OpenAI // Call executes the openai moderation chain with the given context and inputs. // It returns the outputs of the chain or an error, if any. -func (c *OpenAIModeration) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) { +func (c *OpenAI) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) { opts := schema.CallOptions{ CallbackManger: &callback.NoopManager{}, } @@ -116,31 +116,31 @@ func (c *OpenAIModeration) Call(ctx context.Context, inputs schema.ChainValues, } // Memory returns the memory associated with the chain. -func (c *OpenAIModeration) Memory() schema.Memory { +func (c *OpenAI) Memory() schema.Memory { return nil } // Type returns the type of the chain. -func (c *OpenAIModeration) Type() string { +func (c *OpenAI) Type() string { return "OpenAIModeration" } // Verbose returns the verbosity setting of the chain. -func (c *OpenAIModeration) Verbose() bool { +func (c *OpenAI) Verbose() bool { return c.opts.CallbackOptions.Verbose } // Callbacks returns the callbacks associated with the chain. -func (c *OpenAIModeration) Callbacks() []schema.Callback { +func (c *OpenAI) Callbacks() []schema.Callback { return c.opts.CallbackOptions.Callbacks } // InputKeys returns the expected input keys. -func (c *OpenAIModeration) InputKeys() []string { +func (c *OpenAI) InputKeys() []string { return []string{c.opts.InputKey} } // OutputKeys returns the output keys the chain will return. -func (c *OpenAIModeration) OutputKeys() []string { +func (c *OpenAI) OutputKeys() []string { return []string{c.opts.OutputKey} } diff --git a/chain/openai_moderation_test.go b/moderation/openai_test.go similarity index 91% rename from chain/openai_moderation_test.go rename to moderation/openai_test.go index 1c724ed..2f34bdc 100644 --- a/chain/openai_moderation_test.go +++ b/moderation/openai_test.go @@ -1,4 +1,4 @@ -package chain +package moderation import ( "context" @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestOpenAIModeration(t *testing.T) { +func TestOpenAI(t *testing.T) { // Test cases testCases := []struct { name string @@ -43,7 +43,7 @@ func TestOpenAIModeration(t *testing.T) { Results: []openai.Result{{Flagged: tc.flagged}}, }, } - chain, err := NewOpenAIModerationFromClient(fakeClient) + chain, err := NewOpenAIFromClient(fakeClient) require.NoError(t, err) // Test @@ -76,7 +76,7 @@ func TestOpenAIModeration(t *testing.T) { }, } - chain, err := NewOpenAIModerationFromClient(fakeClient, func(o *OpenAIModerationOptions) { + chain, err := NewOpenAIFromClient(fakeClient, func(o *OpenAIOptions) { o.OpenAIModerateFunc = func(id, model string, result openai.Result) (schema.ChainValues, error) { if result.Flagged { return nil, errors.New("custom content policy violation") @@ -112,7 +112,7 @@ func TestOpenAIModeration(t *testing.T) { }, } - chain, err := NewOpenAIModerationFromClient(fakeClient, func(o *OpenAIModerationOptions) { + chain, err := NewOpenAIFromClient(fakeClient, func(o *OpenAIOptions) { o.OpenAIModerateFunc = func(id, model string, result openai.Result) (schema.ChainValues, error) { if result.Flagged { return nil, errors.New("custom content policy violation")