Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update openai function calling to new version #208

Merged
merged 5 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/sse/v2 v2.10.0
github.com/sashabaranov/go-openai v1.24.0
github.com/sashabaranov/go-openai v1.25.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
golang.org/x/text v0.16.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktE
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sashabaranov/go-openai v1.24.0 h1:4H4Pg8Bl2RH/YSnU8DYumZbuHnnkfioor/dtNlB20D4=
github.com/sashabaranov/go-openai v1.24.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.25.0 h1:3h3DtJ55zQJqc+BR4y/iTcPhLk4pewJpyO+MXW2RdW0=
github.com/sashabaranov/go-openai v1.25.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
Expand Down
112 changes: 76 additions & 36 deletions server/ai/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI

func modifyCompletionRequestWithConversation(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation) openaiClient.ChatCompletionRequest {
request.Messages = postsToChatCompletionMessages(conversation.Posts)
request.Functions = toolsToFunctionDefinitions(conversation.Tools.GetTools()) //nolint:all
request.Tools = toolsToOpenAITools(conversation.Tools.GetTools())
return request
}

func toolsToFunctionDefinitions(tools []ai.Tool) []openaiClient.FunctionDefinition {
result := make([]openaiClient.FunctionDefinition, 0, len(tools))
func toolsToOpenAITools(tools []ai.Tool) []openaiClient.Tool {
result := make([]openaiClient.Tool, 0, len(tools))

schemaMaker := jsonschema.Reflector{
Anonymous: true,
Expand All @@ -102,10 +102,13 @@ func toolsToFunctionDefinitions(tools []ai.Tool) []openaiClient.FunctionDefiniti

for _, tool := range tools {
schema := schemaMaker.Reflect(tool.Schema)
result = append(result, openaiClient.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Parameters: schema,
result = append(result, openaiClient.Tool{
Type: openaiClient.ToolTypeFunction,
Function: &openaiClient.FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
Parameters: schema,
},
})
}

Expand Down Expand Up @@ -183,18 +186,10 @@ func createFunctionArrgmentResolver(jsonArgs string) ai.ToolArgumentGetter {
}
}

func (s *OpenAI) handleStreamFunctionCall(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation, name, arguments string) (openaiClient.ChatCompletionRequest, error) {
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
if err != nil {
fmt.Println("Error resolving function: ", err)
}
request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{
Role: openaiClient.ChatMessageRoleFunction,
Name: name,
Content: toolResult,
})

return request, nil
type ToolBufferElement struct {
id strings.Builder
name strings.Builder
args strings.Builder
}

func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation, output chan<- string, errChan chan<- error) {
Expand Down Expand Up @@ -236,9 +231,8 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque

defer stream.Close()

// Buffering in the case of a function call.
functionName := strings.Builder{}
functionArguments := strings.Builder{}
// Buffering in the case of tool use
var toolsBuffer map[int]*ToolBufferElement
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
Expand Down Expand Up @@ -266,11 +260,11 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque
// Not done yet, keep going
case openaiClient.FinishReasonStop:
return
case openaiClient.FinishReasonFunctionCall:
case openaiClient.FinishReasonToolCalls:
// Verify OpenAI functions are not recursing too deep.
numFunctionCalls := 0
for i := len(request.Messages) - 1; i >= 0; i-- {
if request.Messages[i].Role == openaiClient.ChatMessageRoleFunction {
if request.Messages[i].Role == openaiClient.ChatMessageRoleTool {
numFunctionCalls++
} else {
break
Expand All @@ -281,26 +275,72 @@ func (s *OpenAI) streamResultToChannels(request openaiClient.ChatCompletionReque
return
}

// Call ourselves again with the result of the function call
recursiveRequest, err := s.handleStreamFunctionCall(request, conversation, functionName.String(), functionArguments.String())
if err != nil {
errChan <- err
return
// Transfer the buffered tools into tool calls
tools := []openaiClient.ToolCall{}
for i, tool := range toolsBuffer {
name := tool.name.String()
arguments := tool.args.String()
toolID := tool.id.String()
num := i
tools = append(tools, openaiClient.ToolCall{
Function: openaiClient.FunctionCall{
Name: name,
Arguments: arguments,
},
ID: toolID,
Index: &num,
Type: openaiClient.ToolTypeFunction,
})
}
s.streamResultToChannels(recursiveRequest, conversation, output, errChan)

// Add the tool calls to the request
request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{
Role: openaiClient.ChatMessageRoleAssistant,
ToolCalls: tools,
})

// Resolve the tools and create messages for each
for _, tool := range tools {
name := tool.Function.Name
arguments := tool.Function.Arguments
toolID := tool.ID
toolResult, err := conversation.Tools.ResolveTool(name, createFunctionArrgmentResolver(arguments), conversation.Context)
if err != nil {
fmt.Printf("Error resolving function %s: %s", name, err)
}
request.Messages = append(request.Messages, openaiClient.ChatCompletionMessage{
Role: openaiClient.ChatMessageRoleTool,
Name: name,
Content: toolResult,
ToolCallID: toolID,
})
}

// Call ourselves again with the result of the function call
s.streamResultToChannels(request, conversation, output, errChan)
return
default:
fmt.Printf("Unknown finish reason: %s", response.Choices[0].FinishReason)
return
}

// Keep track of any function call received
if response.Choices[0].Delta.FunctionCall != nil {
if response.Choices[0].Delta.FunctionCall.Name != "" {
functionName.WriteString(response.Choices[0].Delta.FunctionCall.Name)
delta := response.Choices[0].Delta
numTools := len(delta.ToolCalls)
if numTools != 0 {
if toolsBuffer == nil {
toolsBuffer = make(map[int]*ToolBufferElement)
}
if response.Choices[0].Delta.FunctionCall.Arguments != "" {
functionArguments.WriteString(response.Choices[0].Delta.FunctionCall.Arguments)
for _, toolCall := range delta.ToolCalls {
if toolCall.Index == nil {
continue
}
toolIndex := *toolCall.Index
if toolsBuffer[toolIndex] == nil {
toolsBuffer[toolIndex] = &ToolBufferElement{}
}
toolsBuffer[toolIndex].name.WriteString(toolCall.Function.Name)
toolsBuffer[toolIndex].args.WriteString(toolCall.Function.Arguments)
toolsBuffer[toolIndex].id.WriteString(toolCall.ID)
}
}

Expand Down
20 changes: 5 additions & 15 deletions server/ai/prompts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ import (
"errors"
)

type BuiltInToolsFunc func(isDM bool) []Tool

type Prompts struct {
templates *template.Template
getBuiltInTools BuiltInToolsFunc
templates *template.Template
}

const PromptExtension = "tmpl"
Expand All @@ -40,33 +37,26 @@ const (
PromptFindOpenQuestionsSince = "find_open_questions_since"
)

func NewPrompts(input fs.FS, getBuiltInTools BuiltInToolsFunc) (*Prompts, error) {
func NewPrompts(input fs.FS) (*Prompts, error) {
templates, err := template.ParseFS(input, "ai/prompts/*")
if err != nil {
return nil, fmt.Errorf("unable to parse prompt templates: %w", err)
}

return &Prompts{
templates: templates,
getBuiltInTools: getBuiltInTools,
templates: templates,
}, nil
}

func withPromptExtension(filename string) string {
return filename + "." + PromptExtension
}

func (p *Prompts) getDefaultTools(isDMWithBot bool) ToolStore {
tools := NewToolStore()
tools.AddTools(p.getBuiltInTools(isDMWithBot))
return tools
}

func (p *Prompts) ChatCompletion(templateName string, context ConversationContext) (BotConversation, error) {
func (p *Prompts) ChatCompletion(templateName string, context ConversationContext, tools ToolStore) (BotConversation, error) {
conversation := BotConversation{
Posts: []Post{},
Context: context,
Tools: p.getDefaultTools(context.IsDMWithBot()),
Tools: tools,
}

template := p.templates.Lookup(withPromptExtension(templateName))
Expand Down
44 changes: 40 additions & 4 deletions server/ai/tools.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai

import (
"encoding/json"
"errors"
)

Expand All @@ -14,12 +15,28 @@ type Tool struct {
type ToolArgumentGetter func(args any) error

type ToolStore struct {
tools map[string]Tool
tools map[string]Tool
log TraceLog
doTrace bool
}

func NewToolStore() ToolStore {
type TraceLog interface {
Info(message string, keyValuePairs ...any)
}

func NewNoTools() ToolStore {
return ToolStore{
tools: make(map[string]Tool),
log: nil,
doTrace: false,
}
}

func NewToolStore(log TraceLog, doTrace bool) ToolStore {
return ToolStore{
tools: make(map[string]Tool),
tools: make(map[string]Tool),
log: log,
doTrace: doTrace,
}
}

Expand All @@ -32,9 +49,12 @@ func (s *ToolStore) AddTools(tools []Tool) {
func (s *ToolStore) ResolveTool(name string, argsGetter ToolArgumentGetter, context ConversationContext) (string, error) {
tool, ok := s.tools[name]
if !ok {
s.TraceUnknown(name, argsGetter)
return "", errors.New("unknown tool " + name)
}
return tool.Resolver(context, argsGetter)
results, err := tool.Resolver(context, argsGetter)
s.TraceResolved(name, argsGetter, results)
return results, err
}

func (s *ToolStore) GetTools() []Tool {
Expand All @@ -44,3 +64,19 @@ func (s *ToolStore) GetTools() []Tool {
}
return result
}

func (s *ToolStore) TraceUnknown(name string, argsGetter ToolArgumentGetter) {
if s.log != nil && s.doTrace {
var raw json.RawMessage
argsGetter(raw)
s.log.Info("unknown tool called", "name", name, "args", string(raw))
}
}

func (s *ToolStore) TraceResolved(name string, argsGetter ToolArgumentGetter, result string) {
if s.log != nil && s.doTrace {
var raw json.RawMessage
argsGetter(raw)
s.log.Info("tool resolved", "name", name, "args", string(raw), "result", result)
}
}
2 changes: 1 addition & 1 deletion server/api_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (p *Plugin) handleSince(c *gin.Context) {
return
}

prompt, err := p.prompts.ChatCompletion(promptPreset, context)
prompt, err := p.prompts.ChatCompletion(promptPreset, context, p.getDefaultToolsStore(context.IsDMWithBot()))
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
Expand Down
2 changes: 1 addition & 1 deletion server/api_post.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *Plugin) handleReact(c *gin.Context) {

context := p.MakeConversationContext(bot, user, channel, post)
context.PromptParameters = map[string]string{"Message": post.Message}
prompt, err := p.prompts.ChatCompletion(ai.PromptEmojiSelect, context)
prompt, err := p.prompts.ChatCompletion(ai.PromptEmojiSelect, context, ai.NewNoTools())
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
Expand Down
6 changes: 6 additions & 0 deletions server/built_in_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,9 @@ func (p *Plugin) getBuiltInTools(isDM bool) []ai.Tool {

return builtInTools
}

func (p *Plugin) getDefaultToolsStore(isDM bool) ai.ToolStore {
store := ai.NewToolStore(&p.pluginAPI.Log, p.getConfiguration().EnableLLMTrace)
store.AddTools(p.getBuiltInTools(isDM))
return store
}
4 changes: 2 additions & 2 deletions server/meeting_summarization.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (p *Plugin) summarizeTranscription(bot *Bot, transcription *subtitles.Subti
p.pluginAPI.Log.Debug("Split into chunks", "chunks", len(chunks))
for _, chunk := range chunks {
context.PromptParameters = map[string]string{"TranscriptionChunk": chunk}
summarizeChunkPrompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeChunk, context)
summarizeChunkPrompt, err := p.prompts.ChatCompletion(ai.PromptSummarizeChunk, context, p.getDefaultToolsStore(context.IsDMWithBot()))
if err != nil {
return nil, fmt.Errorf("unable to get summarize chunk prompt: %w", err)
}
Expand All @@ -291,7 +291,7 @@ func (p *Plugin) summarizeTranscription(bot *Bot, transcription *subtitles.Subti
}

context.PromptParameters = map[string]string{"Transcription": llmFormattedTranscription, "IsChunked": fmt.Sprintf("%t", isChunked)}
summaryPrompt, err := p.prompts.ChatCompletion(ai.PromptMeetingSummary, context)
summaryPrompt, err := p.prompts.ChatCompletion(ai.PromptMeetingSummary, context, p.getDefaultToolsStore(context.IsDMWithBot()))
if err != nil {
return nil, fmt.Errorf("unable to get meeting summary prompt: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion server/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (p *Plugin) OnActivate() error {
}

var err error
p.prompts, err = ai.NewPrompts(promptsFolder, p.getBuiltInTools)
p.prompts, err = ai.NewPrompts(promptsFolder)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion server/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func SetupTestEnvironment(t *testing.T) *TestEnvironment {
}

var promptErr error
p.prompts, promptErr = ai.NewPrompts(promptsFolder, p.getBuiltInTools)
p.prompts, promptErr = ai.NewPrompts(promptsFolder)
require.NoError(t, promptErr)

p.ffmpegPath = ""
Expand Down
Loading