Skip to content

Commit

Permalink
Merge pull request #674 from Yan-Zero/main
Browse files Browse the repository at this point in the history
fix: Gemini 函数调用的文本转义,以及其他文件类型的 Base64 支持
  • Loading branch information
Calcium-Ion authored Dec 29, 2024
2 parents 65d1cde + 2a15dfc commit a1b864b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 40 deletions.
119 changes: 80 additions & 39 deletions relay/channel/gemini/relay-gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"unicode/utf8"

"github.com/gin-gonic/gin"
)
Expand Down Expand Up @@ -203,13 +204,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
MimeType: format,
Data: base64String,
},
})
Expand Down Expand Up @@ -279,57 +280,97 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
return v
}

// func (g *GeminiChatResponse) GetResponseText() string {
// if g == nil {
// return ""
// }
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
// return g.Candidates[0].Content.Parts[0].Text
// }
// return ""
// }
// unescapeString resolves backslash escape sequences in s in a single pass.
//
// The recognized sequences are \" \\ \/ \' \b \f \n \r and \t. A backslash
// followed by any other character is not an error: both the backslash and the
// character are copied to the output unchanged. A lone backslash at the very
// end of s is likewise kept as-is rather than being dropped.
//
// It returns an error only when s is not valid UTF-8.
func unescapeString(s string) (string, error) {
	var out strings.Builder
	out.Grow(len(s)) // the result is never longer than the input

	escaped := false // true when the previous rune was an unconsumed backslash
	for i := 0; i < len(s); {
		r, size := utf8.DecodeRuneInString(s[i:])
		// An invalid byte sequence decodes to (RuneError, 1). A RuneError with
		// a larger width is a genuine U+FFFD character present in the input
		// and must be passed through, not rejected.
		if r == utf8.RuneError && size == 1 {
			return "", fmt.Errorf("invalid UTF-8 encoding")
		}
		i += size

		if !escaped {
			if r == '\\' {
				escaped = true
			} else {
				out.WriteRune(r)
			}
			continue
		}

		// r is the character immediately following a backslash.
		escaped = false
		switch r {
		case '"', '\\', '/', '\'':
			out.WriteRune(r)
		case 'b':
			out.WriteByte('\b')
		case 'f':
			out.WriteByte('\f')
		case 'n':
			out.WriteByte('\n')
		case 'r':
			out.WriteByte('\r')
		case 't':
			out.WriteByte('\t')
		default:
			// Unknown escape: emit it verbatim.
			out.WriteByte('\\')
			out.WriteRune(r)
		}
	}

	// The input ended right after a backslash; keep it instead of losing data.
	if escaped {
		out.WriteByte('\\')
	}
	return out.String(), nil
}
// unescapeMapOrSlice walks a decoded JSON-like value and runs unescapeString
// on every string it finds.
//
// Maps (map[string]interface{}) and slices ([]interface{}) are updated in
// place, element by element, and the same container is returned. A string is
// returned unescaped, or unchanged if unescapeString reports an error. Values
// of any other type are returned as-is.
func unescapeMapOrSlice(data interface{}) interface{} {
	if str, isString := data.(string); isString {
		decoded, err := unescapeString(str)
		if err != nil {
			// Best effort: keep the original text when it cannot be decoded.
			return str
		}
		return decoded
	}

	if obj, isMap := data.(map[string]interface{}); isMap {
		for key := range obj {
			obj[key] = unescapeMapOrSlice(obj[key])
		}
		return data
	}

	if list, isSlice := data.([]interface{}); isSlice {
		for idx := range list {
			list[idx] = unescapeMapOrSlice(list[idx])
		}
	}
	return data
}

// getToolCall converts a Gemini function-call part into an OpenAI-style tool
// call with a freshly generated "call_<uuid>" ID.
//
// When the function-call arguments are a map[string]interface{}, every string
// inside them is passed through unescapeMapOrSlice (which rewrites the map in
// place) before the arguments are serialized to JSON; arguments of any other
// type are serialized unchanged.
// NOTE(review): the unescaping exists because Gemini appears to return string
// values with backslashes already escaped — confirm against the Gemini API.
//
// It returns nil when item carries no function call or when the arguments
// cannot be marshaled to JSON.
func getToolCall(item *GeminiPart) *dto.ToolCall {
	// Guard against parts that are not function calls instead of panicking on
	// a nil dereference below.
	if item == nil || item.FunctionCall == nil {
		return nil
	}

	var (
		argsBytes []byte
		err       error
	)
	if args, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
		argsBytes, err = json.Marshal(unescapeMapOrSlice(args))
	} else {
		argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
	}
	if err != nil {
		//common.SysError("getToolCall failed: " + err.Error())
		return nil
	}

	return &dto.ToolCall{
		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
		Type: "function",
		Function: dto.FunctionCall{
			Arguments: string(argsBytes),
			Name:      item.FunctionCall.FunctionName,
		},
	}
}

// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
// var toolCalls []dto.ToolCall

// item := candidate.Content.Parts[index]
// if item.FunctionCall == nil {
// return toolCalls
// }
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
// if err != nil {
// //common.SysError("getToolCalls failed: " + err.Error())
// return toolCalls
// }
// toolCall := dto.ToolCall{
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
// Type: "function",
// Function: dto.FunctionCall{
// Arguments: string(argsBytes),
// Name: item.FunctionCall.FunctionName,
// },
// }
// toolCalls = append(toolCalls, toolCall)
// return toolCalls
// }

func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Expand Down
28 changes: 27 additions & 1 deletion service/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"encoding/base64"
"errors"
"fmt"
"golang.org/x/image/webp"
"image"
"io"
"one-api/common"
"strings"

"golang.org/x/image/webp"
)

func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
Expand All @@ -31,6 +32,31 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err
}

// DecodeBase64FileData splits a data-URL-style string of the form
// "<scheme>:<mime type>;<params>,<payload>" into its MIME type and payload.
//
// It returns the MIME type, the payload text after the first comma, and a nil
// error when the header before that comma contains both a ';' and, ahead of
// it, a ':'. The payload itself is not decoded or validated on this path.
//
// If the comma, the ';' or the ':' is missing, the input is handed to
// DecodeBase64ImageData instead and the result is reported as
// "image/<format>" together with whatever data and error that call returns.
func DecodeBase64FileData(base64String string) (string, string, error) {
	// decodeAsImage is the shared fallback for inputs without a usable header.
	decodeAsImage := func(raw string) (string, string, error) {
		_, imgFormat, encoded, err := DecodeBase64ImageData(raw)
		return "image/" + imgFormat, encoded, err
	}

	comma := strings.Index(base64String, ",")
	if comma < 0 {
		return decodeAsImage(base64String)
	}
	header, payload := base64String[:comma], base64String[comma+1:]

	// Drop everything from the first ';' on (e.g. ";base64").
	semi := strings.Index(header, ";")
	if semi < 0 {
		return decodeAsImage(payload)
	}
	header = header[:semi]

	// What follows the first ':' is the MIME type.
	colon := strings.Index(header, ":")
	if colon < 0 {
		return decodeAsImage(payload)
	}
	return header[colon+1:], payload, nil
}

// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoDownloadRequest(url)
Expand Down

0 comments on commit a1b864b

Please sign in to comment.