diff --git a/_examples/retrieval_qa/main.go b/_examples/retrieval_qa/main.go index 7f819c8..c3095dd 100644 --- a/_examples/retrieval_qa/main.go +++ b/_examples/retrieval_qa/main.go @@ -31,7 +31,7 @@ func main() { log.Fatal(err) } - result, err := chain.Run(context.Background(), retrievalQAChain, "Why don't scientists trust atoms?") + result, err := retrievalQAChain.Run(context.Background(), "Why don't scientists trust atoms?") if err != nil { log.Fatal(err) } diff --git a/callback/manager.go b/callback/manager.go index e07704a..0e0b6e7 100644 --- a/callback/manager.go +++ b/callback/manager.go @@ -102,3 +102,17 @@ func (m *Manager) OnChainEnd(outputs *schema.ChainValues) error { return nil } + +func (m *Manager) OnChainError(chainError error) error { + for _, c := range m.callbacks { + if m.verbose || c.AlwaysVerbose() { + if err := c.OnChainError(chainError); err != nil { + if c.RaiseError() { + return err + } + } + } + } + + return nil +} diff --git a/chain/chain.go b/chain/chain.go index 97bb7e0..8901389 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -4,46 +4,71 @@ import ( "context" "strings" + "github.com/hupe1980/golc/callback" "github.com/hupe1980/golc/schema" ) +type callbackOptions struct { + Callbacks []schema.Callback + Verbose bool +} + type callFunc func(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) -type chain struct { - callFunc callFunc - inputKeys []string - outputKeys []string +type baseChain struct { + chainName string + callFunc callFunc + inputKeys []string + outputKeys []string + memory schema.Memory + callbackOptions *callbackOptions } -func newChain(callFunc callFunc, inputKeys []string, outputKeys []string) *chain { - return &chain{ - callFunc: callFunc, - inputKeys: inputKeys, - outputKeys: outputKeys, +func (bc *baseChain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { + cm := callback.NewManager(bc.callbackOptions.Callbacks, bc.callbackOptions.Verbose) + + if err := cm.OnChainStart(bc.chainName, &inputs); err != nil { + return nil, err } -} -func (c *chain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { - return c.callFunc(ctx, inputs) + output, err := bc.callFunc(ctx, inputs) + if err != nil { + if cbError := cm.OnChainError(err); cbError != nil { + return nil, cbError + } + + return nil, err + } + + if err := cm.OnChainEnd(&output); err != nil { + return nil, err + } + + return output, nil } -func (c *chain) Run(ctx context.Context, input any) (string, error) { - if len(c.inputKeys) != 1 { +func (bc *baseChain) Run(ctx context.Context, input any) (string, error) { + if len(bc.inputKeys) != 1 { return "", ErrMultipleInputsInRun } - if len(c.outputKeys) != 1 { + if len(bc.outputKeys) != 1 { return "", ErrMultipleOutputsInRun } - inputValues := map[string]any{c.inputKeys[0]: input} + inputValues := map[string]any{bc.inputKeys[0]: input} + + // TODO + if bc.memory != nil { + _, _ = bc.memory.LoadMemoryVariables(inputValues) + } - outputValues, err := c.Call(ctx, inputValues) + outputValues, err := bc.Call(ctx, inputValues) if err != nil { return "", err } - outputValue, ok := outputValues[c.outputKeys[0]].(string) + outputValue, ok := outputValues[bc.outputKeys[0]].(string) if !ok { return "", ErrWrongOutputTypeInRun } @@ -51,7 +76,7 @@ func (c *chain) Run(ctx context.Context, input any) (string, error) { return strings.TrimSpace(outputValue), nil } -func (c *chain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schema.ChainValues, error) { +func (bc *baseChain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schema.ChainValues, error) { chainValues := []schema.ChainValues{} for _, input := range inputs { @@ -59,7 +84,7 @@ func (c *chain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schem case <-ctx.Done(): return nil, ctx.Err() default: - vals, err := c.Call(ctx, input) + vals, err := bc.Call(ctx, input) if err != nil { return nil, err } @@ -72,16 +97,11 @@ func (c *chain) Apply(ctx context.Context, inputs []schema.ChainValues) ([]schem } // InputKeys returns the expected input keys. -func (c *chain) InputKeys() []string { - return c.inputKeys +func (bc *baseChain) InputKeys() []string { + return bc.inputKeys } // OutputKeys returns the output keys the chain will return. -func (c *chain) OutputKeys() []string { - return c.outputKeys -} - -type callbackOptions struct { - Callbacks []schema.Callback - Verbose bool +func (bc *baseChain) OutputKeys() []string { + return bc.outputKeys } diff --git a/chain/conversation_chain.go b/chain/conversation_chain.go new file mode 100644 index 0000000..95fb8b6 --- /dev/null +++ b/chain/conversation_chain.go @@ -0,0 +1,57 @@ +package chain + +import ( + "context" + + "github.com/hupe1980/golc" + "github.com/hupe1980/golc/memory" + "github.com/hupe1980/golc/prompt" + "github.com/hupe1980/golc/schema" +) + +type ConversationChainOptions struct { + *callbackOptions + Memory schema.Memory + OutputParser schema.OutputParser[any] +} + +type ConversationChain struct { + *baseChain + llm schema.LLM + prompt *prompt.Template + opts ConversationChainOptions +} + +func NewConversationChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *ConversationChainOptions)) (*ConversationChain, error) { + opts := ConversationChainOptions{ + Memory: memory.NewConversationBuffer(), + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) + } + + conversationChain := &ConversationChain{ + prompt: prompt, + llm: llm, + opts: opts, + } + + conversationChain.baseChain = &baseChain{ + chainName: "ConversationChain", + callFunc: conversationChain.call, + inputKeys: []string{"input"}, + outputKeys: []string{"response"}, + memory: opts.Memory, + callbackOptions: opts.callbackOptions, + } + + return conversationChain, nil +} + +func (c *ConversationChain) call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { + return nil, nil +} diff --git a/chain/llm_bash_chain.go b/chain/llm_bash_chain.go index 70cab53..1bd3eb6 100644 --- a/chain/llm_bash_chain.go +++ b/chain/llm_bash_chain.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/integration" "github.com/hupe1980/golc/outputparser" "github.com/hupe1980/golc/prompt" @@ -31,21 +32,29 @@ That is the format. Begin! Question: {{.question}}` type LLMBashChainOptions struct { + *callbackOptions InputKey string OutputKey string } type LLMBashChain struct { - *chain + *baseChain llmChain *LLMChain bashProcess *integration.BashProcess opts LLMBashChainOptions } -func NewLLMBashChain(llmChain *LLMChain) (*LLMBashChain, error) { +func NewLLMBashChain(llmChain *LLMChain, optFns ...func(o *LLMBashChainOptions)) (*LLMBashChain, error) { opts := LLMBashChainOptions{ InputKey: "question", OutputKey: "answer", + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) } bp, err := integration.NewBashProcess() @@ -59,7 +68,13 @@ func NewLLMBashChain(llmChain *LLMChain) (*LLMBashChain, error) { opts: opts, } - bash.chain = newChain(bash.call, []string{opts.InputKey}, []string{opts.OutputKey}) + bash.baseChain = &baseChain{ + chainName: "LLMBashChain", + callFunc: bash.call, + inputKeys: []string{opts.InputKey}, + outputKeys: []string{opts.OutputKey}, + callbackOptions: opts.callbackOptions, + } return bash, nil } diff --git a/chain/llm_chain.go b/chain/llm_chain.go index e4b164e..77b0461 100644 --- a/chain/llm_chain.go +++ b/chain/llm_chain.go @@ -4,19 +4,19 @@ import ( "context" "github.com/hupe1980/golc" - "github.com/hupe1980/golc/callback" "github.com/hupe1980/golc/prompt" "github.com/hupe1980/golc/schema" ) type LLMChainOptions struct { - callbackOptions + *callbackOptions + Memory schema.Memory OutputKey string OutputParser schema.OutputParser[any] } type LLMChain struct { - *chain + *baseChain llm schema.LLM prompt *prompt.Template opts LLMChainOptions @@ -25,7 +25,7 @@ type LLMChain struct { func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMChainOptions)) (*LLMChain, error) { opts := LLMChainOptions{ OutputKey: "text", - callbackOptions: callbackOptions{ + callbackOptions: &callbackOptions{ Verbose: golc.Verbose, }, } @@ -40,17 +40,20 @@ func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMC opts: opts, } - llmChain.chain = newChain(llmChain.call, prompt.InputVariables(), []string{opts.OutputKey}) + llmChain.baseChain = &baseChain{ + chainName: "LLMChain", + callFunc: llmChain.call, + inputKeys: prompt.InputVariables(), + outputKeys: []string{opts.OutputKey}, + memory: opts.Memory, + callbackOptions: opts.callbackOptions, + } return llmChain, nil } -func (c *LLMChain) Type() string { - return "llm_chain" -} - -func (c *LLMChain) Predict(ctx context.Context, values schema.ChainValues) (string, error) { - output, err := c.Call(ctx, values) +func (c *LLMChain) Predict(ctx context.Context, inputs schema.ChainValues) (string, error) { + output, err := c.Call(ctx, inputs) if err != nil { return "", err } @@ -58,14 +61,8 @@ func (c *LLMChain) Predict(ctx context.Context, values schema.ChainValues) (stri return output[c.opts.OutputKey].(string), err } -func (c *LLMChain) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) { - cm := callback.NewManager(c.opts.Callbacks, c.opts.Verbose) - - if err := cm.OnChainStart("LLMChain", &values); err != nil { - return nil, err - } - - promptValue, err := c.prompt.FormatPrompt(values) +func (c *LLMChain) call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { + promptValue, err := c.prompt.FormatPrompt(inputs) if err != nil { return nil, err } @@ -77,17 +74,8 @@ func (c *LLMChain) call(ctx context.Context, values schema.ChainValues) (schema. return nil, err } - output, err := c.getFinalOutput(res.Generations[0]) - if err != nil { - return nil, err - } - - if err := cm.OnChainEnd(&schema.ChainValues{"outputs": output}); err != nil { - return nil, err - } - return schema.ChainValues{ - c.opts.OutputKey: output, + c.opts.OutputKey: c.getFinalOutput(res.Generations), }, nil } @@ -95,8 +83,12 @@ func (c *LLMChain) Prompt() *prompt.Template { return c.prompt } -func (c *LLMChain) getFinalOutput(generations []*schema.Generation) (any, error) { // nolint unparam - completion := generations[0].Text - // TODO Outputparser - return completion, nil +func (c *LLMChain) getFinalOutput(generations [][]*schema.Generation) string { + output := []string{} + for _, generation := range generations { + // Get the text of the top generated string. + output = append(output, generation[0].Text) + } + + return output[0] } diff --git a/chain/refine_documents.go b/chain/refine_documents.go index 31a62a7..2603a42 100644 --- a/chain/refine_documents.go +++ b/chain/refine_documents.go @@ -5,12 +5,14 @@ import ( "fmt" "strings" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/prompt" "github.com/hupe1980/golc/schema" "github.com/hupe1980/golc/util" ) type RefineDocumentsOptions struct { + *callbackOptions InputKey string DocumentVariableName string InitialResponseName string @@ -19,18 +21,25 @@ type RefineDocumentsOptions struct { } type RefineDocumentsChain struct { - *chain + *baseChain llmChain *LLMChain refineLLMChain *LLMChain opts RefineDocumentsOptions } -func NewRefineDocumentsChain(llmChain *LLMChain, refineLLMChain *LLMChain) (*RefineDocumentsChain, error) { +func NewRefineDocumentsChain(llmChain *LLMChain, refineLLMChain *LLMChain, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocumentsChain, error) { opts := RefineDocumentsOptions{ InputKey: "inputDocuments", DocumentVariableName: "context", InitialResponseName: "existingAnswer", OutputKey: "text", + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) } if opts.DocumentPrompt == nil { @@ -48,7 +57,13 @@ func NewRefineDocumentsChain(llmChain *LLMChain, refineLLMChain *LLMChain) (*Ref opts: opts, } - refine.chain = newChain(refine.call, []string{opts.InputKey}, llmChain.OutputKeys()) + refine.baseChain = &baseChain{ + chainName: "RefineDocumentsChain", + callFunc: refine.call, + inputKeys: []string{opts.InputKey}, + outputKeys: llmChain.OutputKeys(), + callbackOptions: opts.callbackOptions, + } return refine, nil } diff --git a/chain/retrieval_qa.go b/chain/retrieval_qa.go index 0ac50af..7a508ea 100644 --- a/chain/retrieval_qa.go +++ b/chain/retrieval_qa.go @@ -4,26 +4,35 @@ import ( "context" "fmt" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/prompt" "github.com/hupe1980/golc/schema" ) type RetrievalQAOptions struct { + *callbackOptions InputKey string ReturnSourceDocuments bool } type RetrievalQA struct { - *chain + *baseChain stuffDocumentsChain *StuffDocumentsChain retriever schema.Retriever opts RetrievalQAOptions } -func NewRetrievalQA(stuffDocumentsChain *StuffDocumentsChain, retriever schema.Retriever) (*RetrievalQA, error) { +func NewRetrievalQA(stuffDocumentsChain *StuffDocumentsChain, retriever schema.Retriever, optFns ...func(o *RetrievalQAOptions)) (*RetrievalQA, error) { opts := RetrievalQAOptions{ InputKey: "query", ReturnSourceDocuments: false, + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) } qa := &RetrievalQA{ @@ -32,7 +41,13 @@ func NewRetrievalQA(stuffDocumentsChain *StuffDocumentsChain, retriever schema.R opts: opts, } - qa.chain = newChain(qa.call, []string{opts.InputKey}, stuffDocumentsChain.OutputKeys()) + qa.baseChain = &baseChain{ + chainName: "RetrievalQA", + callFunc: qa.call, + inputKeys: []string{opts.InputKey}, + outputKeys: stuffDocumentsChain.OutputKeys(), + callbackOptions: opts.callbackOptions, + } return qa, nil } diff --git a/chain/stuff_documents.go b/chain/stuff_documents.go index 0bd4a3e..7c2b924 100644 --- a/chain/stuff_documents.go +++ b/chain/stuff_documents.go @@ -5,20 +5,20 @@ import ( "fmt" "strings" - "github.com/hupe1980/golc/callback" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/schema" "github.com/hupe1980/golc/util" ) type StuffDocumentsOptions struct { - callbackOptions + *callbackOptions InputKey string DocumentVariableName string Separator string } type StuffDocumentsChain struct { - *chain + *baseChain llmChain *LLMChain opts StuffDocumentsOptions } @@ -28,6 +28,9 @@ func NewStuffDocumentsChain(llmChain *LLMChain, optFns ...func(o *StuffDocuments InputKey: "inputDocuments", DocumentVariableName: "context", Separator: "\n\n", + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, } for _, fn := range optFns { @@ -39,18 +42,18 @@ func NewStuffDocumentsChain(llmChain *LLMChain, optFns ...func(o *StuffDocuments opts: opts, } - stuff.chain = newChain(stuff.call, []string{opts.InputKey}, llmChain.OutputKeys()) + stuff.baseChain = &baseChain{ + chainName: "StuffDocumentsChain", + callFunc: stuff.call, + inputKeys: []string{opts.InputKey}, + outputKeys: llmChain.OutputKeys(), + callbackOptions: opts.callbackOptions, + } return stuff, nil } func (stuff *StuffDocumentsChain) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) { - cm := callback.NewManager(stuff.opts.Callbacks, stuff.opts.Verbose) - - if err := cm.OnChainStart("StuffDocumentsChain", &values); err != nil { - return nil, err - } - input, ok := values[stuff.opts.InputKey] if !ok { return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, stuff.opts.InputKey) @@ -69,9 +72,5 @@ func (stuff *StuffDocumentsChain) call(ctx context.Context, values schema.ChainV inputValues := util.CopyMap(values) inputValues[stuff.opts.DocumentVariableName] = strings.Join(contents, stuff.opts.Separator) - if err := cm.OnChainEnd(&schema.ChainValues{"outputs": inputValues}); err != nil { - return nil, err - } - return stuff.llmChain.Call(ctx, inputValues) } diff --git a/chain/summarization.go b/chain/summarization.go index 873c396..9a1a22d 100644 --- a/chain/summarization.go +++ b/chain/summarization.go @@ -15,12 +15,12 @@ const stuffSummarizationTemplate = `Write a concise summary of the following: CONCISE SUMMARY:` type StuffSummarizationChainOptions struct { - callbackOptions + *callbackOptions } func NewStuffSummarizationChain(llm schema.LLM, optFns ...func(o *StuffSummarizationChainOptions)) (*StuffDocumentsChain, error) { opts := StuffSummarizationChainOptions{ - callbackOptions: callbackOptions{ + callbackOptions: &callbackOptions{ Verbose: golc.Verbose, }, } @@ -60,7 +60,7 @@ If the context isn't useful, return the original summary. REFINED SUMMARY:` type RefineSummarizationChainOptions struct { - callbackOptions + *callbackOptions } func NewRefineSummarizationChain(llm schema.LLM, optFns ...func(o *RefineSummarizationChainOptions)) (*RefineDocumentsChain, error) { @@ -87,10 +87,14 @@ func NewRefineSummarizationChain(llm schema.LLM, optFns ...func(o *RefineSummari return nil, err } - refineLLMChain, err := NewLLMChain(llm, refinePrompt) + refineLLMChain, err := NewLLMChain(llm, refinePrompt, func(o *LLMChainOptions) { + o.callbackOptions = opts.callbackOptions + }) if err != nil { return nil, err } - return NewRefineDocumentsChain(llmChain, refineLLMChain) + return NewRefineDocumentsChain(llmChain, refineLLMChain, func(o *RefineDocumentsOptions) { + o.callbackOptions = opts.callbackOptions + }) } diff --git a/chain/transform.go b/chain/transform.go index 11cc112..838e5d4 100644 --- a/chain/transform.go +++ b/chain/transform.go @@ -3,6 +3,7 @@ package chain import ( "context" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/schema" ) @@ -11,21 +12,39 @@ var _ schema.Chain = (*TransformChain)(nil) type TransformFunc func(inputs schema.ChainValues) (schema.ChainValues, error) +type TransformChainOptions struct { + *callbackOptions +} + type TransformChain struct { - *chain - inputKeys []string - outputKeys []string - transform TransformFunc + *baseChain + transform TransformFunc + opts TransformChainOptions } -func NewTransformChain(inputKeys, outputKeys []string, transform TransformFunc) (*TransformChain, error) { +func NewTransformChain(inputKeys, outputKeys []string, transform TransformFunc, optFns ...func(o *TransformChainOptions)) (*TransformChain, error) { + opts := TransformChainOptions{ + callbackOptions: &callbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) + } + t := &TransformChain{ - inputKeys: inputKeys, - outputKeys: outputKeys, - transform: transform, + transform: transform, + opts: opts, } - t.chain = newChain(t.call, t.inputKeys, t.outputKeys) + t.baseChain = &baseChain{ + chainName: "TransformChain", + callFunc: t.call, + inputKeys: inputKeys, + outputKeys: outputKeys, + callbackOptions: opts.callbackOptions, + } return t, nil } diff --git a/chatmodel/anthropic.go b/chatmodel/anthropic.go index 97304c0..3c896b3 100644 --- a/chatmodel/anthropic.go +++ b/chatmodel/anthropic.go @@ -39,7 +39,7 @@ func NewAnthropic(apiKey string) (*Anthropic, error) { return a, nil } -func (a *Anthropic) generate(ctx context.Context, messages []schema.ChatMessage, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) { +func (a *Anthropic) generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) { res, err := a.client.Complete(ctx, &anthropic.CompletionRequest{ Model: a.opts.ModelName, MaxTokens: a.opts.MaxTokens, diff --git a/chatmodel/chatmodel.go b/chatmodel/chatmodel.go index aa93a24..2ea3d78 100644 --- a/chatmodel/chatmodel.go +++ b/chatmodel/chatmodel.go @@ -7,7 +7,7 @@ import ( "github.com/hupe1980/golc/util" ) -type GenerateFunc func(ctx context.Context, messages []schema.ChatMessage, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) +type GenerateFunc func(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) type ChatModel struct { generateFunc GenerateFunc @@ -47,7 +47,7 @@ func (b *ChatModel) GeneratePrompt(ctx context.Context, promptValues []schema.Pr func (b *ChatModel) Predict(ctx context.Context, text string, optFns ...func(o *schema.GenerateOptions)) (string, error) { message := schema.NewHumanChatMessage(text) - result, err := b.PredictMessages(ctx, []schema.ChatMessage{message}, optFns...) + result, err := b.PredictMessages(ctx, schema.ChatMessages{message}, optFns...) if err != nil { return "", err } @@ -55,7 +55,7 @@ func (b *ChatModel) Predict(ctx context.Context, text string, optFns ...func(o * return result.Text(), nil } -func (b *ChatModel) PredictMessages(ctx context.Context, messages []schema.ChatMessage, optFns ...func(o *schema.GenerateOptions)) (schema.ChatMessage, error) { +func (b *ChatModel) PredictMessages(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (schema.ChatMessage, error) { result, err := b.Generate(ctx, [][]schema.ChatMessage{messages}, optFns...) if err != nil { return nil, err diff --git a/chatmodel/openai.go b/chatmodel/openai.go index ea0b239..24c22bb 100644 --- a/chatmodel/openai.go +++ b/chatmodel/openai.go @@ -60,7 +60,7 @@ func NewOpenAI(apiKey string) (*OpenAI, error) { return o, nil } -func (o *OpenAI) generate(ctx context.Context, messages []schema.ChatMessage, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) { +func (o *OpenAI) generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) { openAIMessages := []openai.ChatCompletionMessage{} for _, message := range messages { diff --git a/llm/llm.go b/llm/llm.go index 3ab8a43..437cfa3 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -39,7 +39,7 @@ func (l *llm) Generate(ctx context.Context, prompts []string, optFns ...func(o * result, err := l.generateFunc(ctx, prompts, opts.Stop) if err != nil { - if cbErr := cm.OnLLMError(err); err != nil { + if cbErr := cm.OnLLMError(err); cbErr != nil { return nil, cbErr } @@ -70,8 +70,8 @@ func (l *llm) Predict(ctx context.Context, text string, optFns ...func(o *schema return result.Generations[0][0].Text, nil } -func (l *llm) PredictMessages(ctx context.Context, messages []schema.ChatMessage, optFns ...func(o *schema.GenerateOptions)) (schema.ChatMessage, error) { - text, err := schema.StringifyChatMessages(messages) +func (l *llm) PredictMessages(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (schema.ChatMessage, error) { + text, err := messages.Format() if err != nil { return nil, err } diff --git a/memory/buffer.go b/memory/buffer.go index 8cef2f3..8c35c9d 100644 --- a/memory/buffer.go +++ b/memory/buffer.go @@ -61,7 +61,7 @@ func (m *ConversationBuffer) LoadMemoryVariables(inputs map[string]any) (map[str }, nil } - buffer, err := schema.StringifyChatMessages(messages, func(o *schema.StringifyChatMessagesOptions) { + buffer, err := messages.Format(func(o *schema.StringifyChatMessagesOptions) { o.HumanPrefix = m.opts.HumanPrefix o.AIPrefix = m.opts.AIPrefix }) diff --git a/memory/buffer_test.go b/memory/buffer_test.go index 0a8c9f1..5257a93 100644 --- a/memory/buffer_test.go +++ b/memory/buffer_test.go @@ -20,7 +20,7 @@ func TestConversationBuffer(t *testing.T) { t.Run("LoadMemoryVariables", func(t *testing.T) { inputs := map[string]interface{}{} - messages := []schema.ChatMessage{ + messages := schema.ChatMessages{ schema.NewHumanChatMessage("Hello"), schema.NewAIChatMessage("Hi there"), } diff --git a/memory/chatmessagehistory/dynamodb.go b/memory/chatmessagehistory/dynamodb.go index 64ed985..0b40df8 100644 --- a/memory/chatmessagehistory/dynamodb.go +++ b/memory/chatmessagehistory/dynamodb.go @@ -32,7 +32,7 @@ func NewDynamoDB(client *dynamodb.Client, tableName, sessionID string) *DynamoDB } } -func (mh *DynamoDB) Messages() ([]schema.ChatMessage, error) { +func (mh *DynamoDB) Messages() (schema.ChatMessages, error) { sessionID, err := attributevalue.Marshal(mh.sessionID) if err != nil { return nil, err diff --git a/memory/chatmessagehistory/in_memory.go b/memory/chatmessagehistory/in_memory.go index 605b892..e240563 100644 --- a/memory/chatmessagehistory/in_memory.go +++ b/memory/chatmessagehistory/in_memory.go @@ -21,7 +21,7 @@ func NewInMemoryWithMessages(messages []schema.ChatMessage) *InMemory { } } -func (mh *InMemory) Messages() ([]schema.ChatMessage, error) { +func (mh *InMemory) Messages() (schema.ChatMessages, error) { return mh.messages, nil } diff --git a/prompt/prompt.go b/prompt/prompt.go index 13e00e5..3508173 100644 --- a/prompt/prompt.go +++ b/prompt/prompt.go @@ -19,8 +19,8 @@ func (v StringPromptValue) String() string { return string(v) } -func (v StringPromptValue) Messages() []schema.ChatMessage { - return []schema.ChatMessage{ +func (v StringPromptValue) Messages() schema.ChatMessages { + return schema.ChatMessages{ schema.NewHumanChatMessage(string(v)), } } diff --git a/schema/chat.go b/schema/chat.go index 3df090f..55e9699 100644 --- a/schema/chat.go +++ b/schema/chat.go @@ -80,7 +80,9 @@ type StringifyChatMessagesOptions struct { SystemPrefix string } -func StringifyChatMessages(messages []ChatMessage, optFns ...func(o *StringifyChatMessagesOptions)) (string, error) { +type ChatMessages []ChatMessage + +func (cm ChatMessages) Format(optFns ...func(o *StringifyChatMessagesOptions)) (string, error) { opts := StringifyChatMessagesOptions{ HumanPrefix: "Human", AIPrefix: "AI", @@ -93,7 +95,7 @@ func StringifyChatMessages(messages []ChatMessage, optFns ...func(o *StringifyCh result := []string{} - for _, message := range messages { + for _, message := range cm { var role string switch message.Type() { @@ -114,3 +116,38 @@ func StringifyChatMessages(messages []ChatMessage, optFns ...func(o *StringifyCh return strings.Join(result, "\n"), nil } + +// func StringifyChatMessages(messages []ChatMessage, optFns ...func(o *StringifyChatMessagesOptions)) (string, error) { +// opts := StringifyChatMessagesOptions{ +// HumanPrefix: "Human", +// AIPrefix: "AI", +// SystemPrefix: "System", +// } + +// for _, fn := range optFns { +// fn(&opts) +// } + +// result := []string{} + +// for _, message := range messages { +// var role string + +// switch message.Type() { +// case ChatMessageTypeHuman: +// role = opts.HumanPrefix +// case ChatMessageTypeAI: +// role = opts.AIPrefix +// case ChatMessageTypeSystem: +// role = opts.SystemPrefix +// case ChatMessageTypeGeneric: +// role = message.(GenericChatMessage).Role() +// default: +// return "", fmt.Errorf("unknown chat message type: %s", message.Type()) +// } + +// result = append(result, fmt.Sprintf("%s: %s", role, message.Text())) +// } + +// return strings.Join(result, "\n"), nil +// } diff --git a/schema/memory.go b/schema/memory.go index d65566f..f19758e 100644 --- a/schema/memory.go +++ b/schema/memory.go @@ -14,7 +14,7 @@ type Memory interface { type ChatMessageHistory interface { // Messages returns the messages stored in the store. - Messages() ([]ChatMessage, error) + Messages() (ChatMessages, error) // Add a user message to the store. AddUserMessage(text string) error // Add an AI message to the store. diff --git a/schema/model.go b/schema/model.go index 46a7264..ff8318f 100644 --- a/schema/model.go +++ b/schema/model.go @@ -23,13 +23,13 @@ type Chain interface { type PromptValue interface { String() string - Messages() []ChatMessage + Messages() ChatMessages } type Tokenizer interface { GetTokenIDs(text string) ([]int, error) GetNumTokens(text string) (int, error) - GetNumTokensFromMessage(messages []ChatMessage) (int, error) + GetNumTokensFromMessage(messages ChatMessages) (int, error) } type Callback interface { @@ -59,7 +59,7 @@ type LLM interface { Tokenizer GeneratePrompt(ctx context.Context, promptValues []PromptValue, optFns ...func(o *GenerateOptions)) (*LLMResult, error) Predict(ctx context.Context, text string, optFns ...func(o *GenerateOptions)) (string, error) - PredictMessages(ctx context.Context, messages []ChatMessage, optFns ...func(o *GenerateOptions)) (ChatMessage, error) + PredictMessages(ctx context.Context, messages ChatMessages, optFns ...func(o *GenerateOptions)) (ChatMessage, error) } // Embedder is the interface for creating vector embeddings from texts. diff --git a/tokenizer/openai.go b/tokenizer/openai.go index ed6bb2e..b9f23db 100644 --- a/tokenizer/openai.go +++ b/tokenizer/openai.go @@ -33,8 +33,8 @@ func (o *OpenAI) GetNumTokens(text string) (int, error) { return len(ids), nil } -func (o *OpenAI) GetNumTokensFromMessage(messages []schema.ChatMessage) (int, error) { - text, err := schema.StringifyChatMessages(messages) +func (o *OpenAI) GetNumTokensFromMessage(messages schema.ChatMessages) (int, error) { + text, err := messages.Format() if err != nil { return 0, err } diff --git a/tokenizer/simple.go b/tokenizer/simple.go index 35f107c..f587d2d 100644 --- a/tokenizer/simple.go +++ b/tokenizer/simple.go @@ -16,6 +16,6 @@ func (t *Simple) GetNumTokens(text string) (int, error) { return 0, nil } -func (t *Simple) GetNumTokensFromMessage(messages []schema.ChatMessage) (int, error) { +func (t *Simple) GetNumTokensFromMessage(messages schema.ChatMessages) (int, error) { return 0, nil }