From d82db1810e34f10f2d927f182bcefae80235aa58 Mon Sep 17 00:00:00 2001 From: Martin Buhr Date: Fri, 13 Sep 2024 09:20:33 +1200 Subject: [PATCH] fixes #1010 --- llms/anthropic/anthropicllm.go | 112 ++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 37 deletions(-) diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go index 2d236ebe5..4494ce295 100644 --- a/llms/anthropic/anthropicllm.go +++ b/llms/anthropic/anthropicllm.go @@ -272,35 +272,52 @@ func handleHumanMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, e return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for human message", ErrInvalidContentType) } -func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) { - if toolCall, ok := msg.Parts[0].(llms.ToolCall); ok { - var inputStruct map[string]interface{} - err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct) - if err != nil { - return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err) - } - toolUse := anthropicclient.ToolUseContent{ - Type: "tool_use", - ID: toolCall.ID, - Name: toolCall.FunctionCall.Name, - Input: inputStruct, - } +func getTextPart(part llms.TextContent) *anthropicclient.TextContent { + return &anthropicclient.TextContent{ + Type: "text", + Text: part.Text, + } +} - return anthropicclient.ChatMessage{ - Role: RoleAssistant, - Content: []anthropicclient.Content{toolUse}, - }, nil +func getToolPart(part llms.ToolCall) (*anthropicclient.ToolUseContent, error) { + var inputStruct map[string]interface{} + err := json.Unmarshal([]byte(part.FunctionCall.Arguments), &inputStruct) + if err != nil { + return nil, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err) } - if textContent, ok := msg.Parts[0].(llms.TextContent); ok { - return anthropicclient.ChatMessage{ - Role: RoleAssistant, - Content: []anthropicclient.Content{&anthropicclient.TextContent{ - Type: "text", - Text: textContent.Text, - }}, - }, nil + return &anthropicclient.ToolUseContent{ + Type: "tool_use", + ID: part.ID, + Name: part.FunctionCall.Name, + Input: inputStruct, + }, nil +} + +func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) { + cm := anthropicclient.ChatMessage{ + Role: RoleAssistant, + Content: make([]anthropicclient.Content, 0), } - return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message", ErrInvalidContentType) + + contentArr := make([]anthropicclient.Content, 0) + + for _, part := range msg.Parts { + switch part.(type) { + case llms.TextContent: + contentArr = append(contentArr, getTextPart(part.(llms.TextContent))) + case llms.ToolCall: + tp, err := getToolPart(part.(llms.ToolCall)) + if err != nil { + return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", err, part) + } + contentArr = append(contentArr, tp) + default: + return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", ErrInvalidContentType, part) + } + } + + cm.Content = contentArr + return cm, nil } type ToolResult struct { @@ -309,18 +326,39 @@ type ToolResult struct { Content string `json:"content"` } +func getToolResponse(part llms.ToolCallResponse) (*anthropicclient.ToolResultContent, error) { + return &anthropicclient.ToolResultContent{ + Type: "tool_result", + ToolUseID: part.ToolCallID, + Content: part.Content, + }, nil + +} + func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) { - if toolCallResponse, ok := msg.Parts[0].(llms.ToolCallResponse); ok { - toolContent := anthropicclient.ToolResultContent{ - Type: "tool_result", - ToolUseID: toolCallResponse.ToolCallID, - Content: toolCallResponse.Content, - } + cm := anthropicclient.ChatMessage{ + Role: RoleUser, + Content: make([]anthropicclient.Content, 0), + } - return anthropicclient.ChatMessage{ - Role: RoleUser, - Content: []anthropicclient.Content{toolContent}, - }, nil + contentArr := make([]anthropicclient.Content, 0) + + for _, part := range msg.Parts { + switch part.(type) { + case llms.TextContent: + contentArr = append(contentArr, getTextPart(part.(llms.TextContent))) + case llms.ToolCallResponse: + tp, err := getToolResponse(part.(llms.ToolCallResponse)) + if err != nil { + return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool part response message %T", err, part) + } + contentArr = append(contentArr, tp) + default: + return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message %T", ErrInvalidContentType, part) + } } - return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message", ErrInvalidContentType) + + cm.Content = contentArr + + return cm, nil }