From f2bfa79e7d3dedb8a9ff24cafe5ea843296c0b4d Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Fri, 14 Jun 2024 16:23:11 -0700 Subject: [PATCH 1/5] Function use trace --- server/ai/prompts.go | 20 ++++----------- server/ai/tools.go | 44 ++++++++++++++++++++++++++++++--- server/api_channel.go | 2 +- server/api_post.go | 2 +- server/built_in_tools.go | 6 +++++ server/meeting_summarization.go | 4 +-- server/plugin.go | 2 +- server/service.go | 8 +++--- 8 files changed, 60 insertions(+), 28 deletions(-) diff --git a/server/ai/prompts.go b/server/ai/prompts.go index 46a72a19..97e71933 100644 --- a/server/ai/prompts.go +++ b/server/ai/prompts.go @@ -9,11 +9,8 @@ import ( "errors" ) -type BuiltInToolsFunc func(isDM bool) []Tool - type Prompts struct { - templates *template.Template - getBuiltInTools BuiltInToolsFunc + templates *template.Template } const PromptExtension = "tmpl" @@ -40,15 +37,14 @@ const ( PromptFindOpenQuestionsSince = "find_open_questions_since" ) -func NewPrompts(input fs.FS, getBuiltInTools BuiltInToolsFunc) (*Prompts, error) { +func NewPrompts(input fs.FS) (*Prompts, error) { templates, err := template.ParseFS(input, "ai/prompts/*") if err != nil { return nil, fmt.Errorf("unable to parse prompt templates: %w", err) } return &Prompts{ - templates: templates, - getBuiltInTools: getBuiltInTools, + templates: templates, }, nil } @@ -56,17 +52,11 @@ func withPromptExtension(filename string) string { return filename + "." + PromptExtension } -func (p *Prompts) getDefaultTools(isDMWithBot bool) ToolStore { - tools := NewToolStore() - tools.AddTools(p.getBuiltInTools(isDMWithBot)) - return tools -} - -func (p *Prompts) ChatCompletion(templateName string, context ConversationContext) (BotConversation, error) { +func (p *Prompts) ChatCompletion(templateName string, context ConversationContext, tools ToolStore) (BotConversation, error) { conversation := BotConversation{ Posts: []Post{}, Context: context, - Tools: p.getDefaultTools(context.IsDMWithBot()), + Tools: tools, } template := p.templates.Lookup(withPromptExtension(templateName)) diff --git a/server/ai/tools.go b/server/ai/tools.go index 4a089e10..db3b7958 100644 --- a/server/ai/tools.go +++ b/server/ai/tools.go @@ -1,6 +1,7 @@ package ai import ( + "encoding/json" "errors" ) @@ -14,12 +15,28 @@ type Tool struct { type ToolArgumentGetter func(args any) error type ToolStore struct { - tools map[string]Tool + tools map[string]Tool + log TraceLog + doTrace bool } -func NewToolStore() ToolStore { +type TraceLog interface { + Info(message string, keyValuePairs ...any) +} + +func NewNoTools() ToolStore { + return ToolStore{ + tools: make(map[string]Tool), + log: nil, + doTrace: false, + } +} + +func NewToolStore(log TraceLog, doTrace bool) ToolStore { return ToolStore{ - tools: make(map[string]Tool), + tools: make(map[string]Tool), + log: log, + doTrace: doTrace, } } @@ -32,9 +49,12 @@ func (s *ToolStore) AddTools(tools []Tool) { func (s *ToolStore) ResolveTool(name string, argsGetter ToolArgumentGetter, context ConversationContext) (string, error) { tool, ok := s.tools[name] if !ok { + s.TraceUnknown(name, argsGetter) return "", errors.New("unknown tool " + name) } - return tool.Resolver(context, argsGetter) + results, err := tool.Resolver(context, argsGetter) + s.TraceResolved(name, argsGetter, results) + return results, err } func (s *ToolStore) GetTools() []Tool { @@ -44,3 +64,19 @@ func (s *ToolStore) GetTools() []Tool { } return result } + +func (s *ToolStore) TraceUnknown(name string, argsGetter ToolArgumentGetter) { + if s.log != nil && s.doTrace { + var raw json.RawMessage + argsGetter(raw) + s.log.Info("unknown tool called", "name", name, "args", string(raw)) + } +} + +func (s *ToolStore) TraceResolved(name string, argsGetter ToolArgumentGetter, result string) { + if s.log != nil && s.doTrace { + var raw json.RawMessage + argsGetter(raw) + s.log.Info("tool resolved", "name", name, "args", string(raw), "result", result) + } +} diff --git a/server/api_channel.go b/server/api_channel.go index f4abe5a9..a826bae8 100644 --- a/server/api_channel.go +++ b/server/api_channel.go @@ -103,7 +103,7 @@ func (p *Plugin) handleSince(c *gin.Context) { return } - prompt, err := p.prompts.ChatCompletion(promptPreset, context) + prompt, err := p.prompts.ChatCompletion(promptPreset, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return diff --git a/server/api_post.go b/server/api_post.go index eb52a620..1fc436e2 100644 --- a/server/api_post.go +++ b/server/api_post.go @@ -60,7 +60,7 @@ func (p *Plugin) handleReact(c *gin.Context) { context := p.MakeConversationContext(bot, user, channel, post) context.PromptParameters = map[string]string{"Message": post.Message} - prompt, err := p.prompts.ChatCompletion(ai.PromptEmojiSelect, context) + prompt, err := p.prompts.ChatCompletion(ai.PromptEmojiSelect, context, ai.NewNoTools()) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return diff --git a/server/built_in_tools.go b/server/built_in_tools.go index d36d19dc..b6eabe78 100644 --- a/server/built_in_tools.go +++ b/server/built_in_tools.go @@ -431,3 +431,9 @@ func (p *Plugin) getBuiltInTools(isDM bool) []ai.Tool { return builtInTools } + +func (p *Plugin) getDefaultToolsStore(isDM bool) ai.ToolStore { + store := ai.NewToolStore(&p.pluginAPI.Log, p.getConfiguration().EnableLLMTrace) + store.AddTools(p.getBuiltInTools(isDM)) + return store +} diff --git a/server/meeting_summarization.go b/server/meeting_summarization.go index 3ca7ce0d..c2730a0e 100644 --- a/server/meeting_summarization.go +++ b/server/meeting_summarization.go @@ -272,7 +272,7 @@ func (p *Plugin) summarizeTranscription(bot *Bot, transcription *subtitles.Subti p.pluginAPI.Log.Debug("Split into chunks", "chunks", len(chunks)) for _, chunk := range chunks { context.PromptParameters = map[string]string{"TranscriptionChunk": chunk} - summarizeChunkPrompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeChunk, context) + summarizeChunkPrompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeChunk, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return nil, fmt.Errorf("unable to get summarize chunk prompt: %w", err) } @@ -291,7 +291,7 @@ func (p *Plugin) summarizeTranscription(bot *Bot, transcription *subtitles.Subti } context.PromptParameters = map[string]string{"Transcription": llmFormattedTranscription, "IsChunked": fmt.Sprintf("%t", isChunked)} - summaryPrompt, err := p.prompts.ChatCompletion(ai.PromptMeetingSummary, context) + summaryPrompt, err := p.prompts.ChatCompletion(ai.PromptMeetingSummary, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return nil, fmt.Errorf("unable to get meeting summary prompt: %w", err) } diff --git a/server/plugin.go b/server/plugin.go index 0a060c28..96cf177b 100644 --- a/server/plugin.go +++ b/server/plugin.go @@ -112,7 +112,7 @@ func (p *Plugin) OnActivate() error { } var err error - p.prompts, err = ai.NewPrompts(promptsFolder, p.getBuiltInTools) + p.prompts, err = ai.NewPrompts(promptsFolder) if err != nil { return err } diff --git a/server/service.go b/server/service.go index a7f97e8b..a0e628f2 100644 --- a/server/service.go +++ b/server/service.go @@ -47,7 +47,7 @@ func (p *Plugin) processUserRequestToBot(bot *Bot, context ai.ConversationContex } func (p *Plugin) newConversation(bot *Bot, context ai.ConversationContext) error { - conversation, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, context) + conversation, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return err } @@ -128,7 +128,7 @@ func (p *Plugin) continueConversation(bot *Bot, threadData *ThreadData, context return nil, err } } else { - prompt, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, context) + prompt, err := p.prompts.ChatCompletion(ai.PromptDirectMessageQuestion, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return nil, err } @@ -151,7 +151,7 @@ func (p *Plugin) continueThreadConversation(bot *Bot, questionThreadData *Thread originalThread := formatThread(originalThreadData) context.PromptParameters = map[string]string{"Thread": originalThread} - prompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeThread, context) + prompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeThread, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return nil, err } @@ -177,7 +177,7 @@ func (p *Plugin) summarizePost(bot *Bot, postIDToSummarize string, context ai.Co formattedThread := formatThread(threadData) context.PromptParameters = map[string]string{"Thread": formattedThread} - prompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeThread, context) + prompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeThread, context, p.getDefaultToolsStore(context.IsDMWithBot())) if err != nil { return nil, err } From 2731ba3965cfc7d86e8cea1397d4b94c2346ebb8 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Mon, 17 Jun 2024 15:59:01 -0700 Subject: [PATCH 2/5] Working functions update --- go.mod | 2 +- go.sum | 6 +- server/ai/openai/openai.go | 109 +++++++++++++++++++++++++------------ 3 files changed, 78 insertions(+), 39 deletions(-) diff --git a/go.mod b/go.mod index e8cfcdd3..cbcf6e3b 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.19.1 github.com/r3labs/sse/v2 v2.10.0 - github.com/sashabaranov/go-openai v1.24.0 + github.com/sashabaranov/go-openai v1.25.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 golang.org/x/text v0.16.0 diff --git a/go.sum b/go.sum index b56dccba..5ea8b9d8 100644 --- a/go.sum +++ b/go.sum @@ -223,8 +223,10 @@ github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktE github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sashabaranov/go-openai v1.24.0 h1:4H4Pg8Bl2RH/YSnU8DYumZbuHnnkfioor/dtNlB20D4= -github.com/sashabaranov/go-openai v1.24.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI= +github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/sashabaranov/go-openai v1.25.0 h1:3h3DtJ55zQJqc+BR4y/iTcPhLk4pewJpyO+MXW2RdW0= +github.com/sashabaranov/go-openai v1.25.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= diff --git a/server/ai/openai/openai.go b/server/ai/openai/openai.go index f3a21b2e..28fcea5a 100644 --- a/server/ai/openai/openai.go +++ b/server/ai/openai/openai.go @@ -88,12 +88,12 @@ func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI func modifyCompletionRequestWithConversation(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation) openaiClient.ChatCompletionRequest { request.Messages = postsToChatCompletionMessages(conversation.Posts) - request.Functions = toolsToFunctionDefinitions(conversation.Tools.GetTools()) //nolint:all + request.Tools = toolsToOpenAITools(conversation.Tools.GetTools()) return request } -func toolsToFunctionDefinitions(tools []ai.Tool) []openaiClient.FunctionDefinition { - result := make([]openaiClient.FunctionDefinition, 0, len(tools)) +func toolsToOpenAITools(tools []ai.Tool) []openaiClient.Tool { + result := make([]openaiClient.Tool, 0, len(tools)) schemaMaker := jsonschema.Reflector{ Anonymous: true, @@ -102,10 +102,13 @@ func toolsToFunctionDefinitions(tools []ai.Tool) []openaiClient.FunctionDefiniti for _, tool := range tools { schema := schemaMaker.Reflect(tool.Schema) - result = append(result, openaiClient.FunctionDefinition{ - Name: tool.Name, - Description: tool.Description, - Parameters: schema, + result = append(result, openaiClient.Tool{ + Type: openaiClient.ToolTypeFunction, + Function: &openaiClient.FunctionDefinition{ + Name: tool.Name, + Description: tool.Description, + Parameters: schema, + }, }) } @@ -183,18 +186,10 @@ func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter { } } -func (s *OpenAI) handleStreamFunctionCall(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation, name, arguments string) (openaiClient.ChatCompletionRequest, error) { - toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context) - if err != nil { - fmt.Println("Error resolving function: ", err) - } - request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{ - Role: openaiClient.ChatMessageRoleFunction, - Name: name, - Content: toolResult, - }) - - return request, nil +type ToolBufferElement struct { + id strings.Builder + name strings.Builder + args strings.Builder } func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation, output chan<- string, errChan chan<- error) { @@ -236,9 +231,8 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque defer stream.Close() - // Buffering in the case of a function call. - functionName := strings.Builder{} - functionArguments := strings.Builder{} + // Buffering in the case of tool use + var toolsBuffer map[int]*ToolBufferElement for { response, err := stream.Recv() if errors.Is(err, io.EOF) { @@ -266,11 +260,11 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque // Not done yet, keep going case openaiClient.FinishReasonStop: return - case openaiClient.FinishReasonFunctionCall: + case openaiClient.FinishReasonToolCalls: // Verify OpenAI functions are not recursing too deep. numFunctionCalls := 0 for i := len(request.Messages) - 1; i >= 0; i-- { - if request.Messages[i].Role == openaiClient.ChatMessageRoleFunction { + if request.Messages[i].Role == openaiClient.ChatMessageRoleTool { numFunctionCalls++ } else { break @@ -281,26 +275,69 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque return } - // Call ourselves again with the result of the function call - recursiveRequest, err := s.handleStreamFunctionCall(request, conversation, functionName.String(), functionArguments.String()) - if err != nil { - errChan <- err - return + tools := []openaiClient.ToolCall{} + for i, tool := range toolsBuffer { + name := tool.name.String() + arguments := tool.args.String() + toolID := tool.id.String() + num := i + tools = append(tools, openaiClient.ToolCall{ + Function: openaiClient.FunctionCall{ + Name: name, + Arguments: arguments, + }, + ID: toolID, + Index: &num, + Type: openaiClient.ToolTypeFunction, + }) + } + + request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{ + Role: openaiClient.ChatMessageRoleAssistant, + ToolCalls: tools, + }) + + for _, tool := range toolsBuffer { + name := tool.name.String() + arguments := tool.args.String() + toolID := tool.id.String() + toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context) + if err != nil { + fmt.Printf("Error resolving function %s: %s", name, err) + } + request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{ + Role: openaiClient.ChatMessageRoleTool, + Name: name, + Content: toolResult, + ToolCallID: toolID, + }) } - s.streamResultToChannels(recursiveRequest, conversation, output, errChan) + + // Call ourselves again with the result of the function call + s.streamResultToChannels(request, conversation, output, errChan) return default: fmt.Printf("Unknown finish reason: %s", response.Choices[0].FinishReason) return } - // Keep track of any function call received - if response.Choices[0].Delta.FunctionCall != nil { - if response.Choices[0].Delta.FunctionCall.Name != "" { - functionName.WriteString(response.Choices[0].Delta.FunctionCall.Name) + delta := response.Choices[0].Delta + numTools := len(delta.ToolCalls) + if numTools != 0 { + if toolsBuffer == nil { + toolsBuffer = make(map[int]*ToolBufferElement) } - if response.Choices[0].Delta.FunctionCall.Arguments != "" { - functionArguments.WriteString(response.Choices[0].Delta.FunctionCall.Arguments) + for _, toolCall := range delta.ToolCalls { + if toolCall.Index == nil { + continue + } + toolIndex := *toolCall.Index + if toolsBuffer[toolIndex] == nil { + toolsBuffer[toolIndex] = &ToolBufferElement{} + } + toolsBuffer[toolIndex].name.WriteString(toolCall.Function.Name) + toolsBuffer[toolIndex].args.WriteString(toolCall.Function.Arguments) + toolsBuffer[toolIndex].id.WriteString(toolCall.ID) } } From b2d33808dd5aa08ec381e32b87f83fdd229c05ea Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Tue, 18 Jun 2024 11:22:09 -0700 Subject: [PATCH 3/5] Tests --- server/plugin_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/plugin_test.go b/server/plugin_test.go index 9294b52b..3360a89c 100644 --- a/server/plugin_test.go +++ b/server/plugin_test.go @@ -31,7 +31,7 @@ func SetupTestEnvironment(t *testing.T) *TestEnvironment { } var promptErr error - p.prompts, promptErr = ai.NewPrompts(promptsFolder, p.getBuiltInTools) + p.prompts, promptErr = ai.NewPrompts(promptsFolder) require.NoError(t, promptErr) p.ffmpegPath = "" From bd291e799cfedd138111c717693f30c32055b669 Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Tue, 18 Jun 2024 11:31:06 -0700 Subject: [PATCH 4/5] Cleanup --- server/ai/openai/openai.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/server/ai/openai/openai.go b/server/ai/openai/openai.go index 28fcea5a..4a5db9f1 100644 --- a/server/ai/openai/openai.go +++ b/server/ai/openai/openai.go @@ -275,6 +275,7 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque return } + // Transfer the buffered tools into tool calls tools := []openaiClient.ToolCall{} for i, tool := range toolsBuffer { name := tool.name.String() @@ -292,15 +293,17 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque }) } + // Add the tool calls to the request request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{ Role: openaiClient.ChatMessageRoleAssistant, ToolCalls: tools, }) - for _, tool := range toolsBuffer { - name := tool.name.String() - arguments := tool.args.String() - toolID := tool.id.String() + // Resolve the tools and create messages for each + for _, tool := range tools { + name := tool.Function.Name + arguments := tool.Function.Arguments + toolID := tool.ID toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context) if err != nil { fmt.Printf("Error resolving function %s: %s", name, err) From 21de3a34744cf76f32b9ed456e84983d58a12dec Mon Sep 17 00:00:00 2001 From: Christopher Speller Date: Tue, 18 Jun 2024 11:36:39 -0700 Subject: [PATCH 5/5] mod tidy --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 5ea8b9d8..396a2fd5 100644 --- a/go.sum +++ b/go.sum @@ -223,8 +223,6 @@ github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktE github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI= -github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.25.0 h1:3h3DtJ55zQJqc+BR4y/iTcPhLk4pewJpyO+MXW2RdW0= github.com/sashabaranov/go-openai v1.25.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=