From c872d96aad6a52fc5c1603be2441c1ba60c69ef6 Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Mon, 23 Dec 2024 00:08:54 +0800 Subject: [PATCH] feat: support google thinking --- service/aiproxy/common/conv/any.go | 9 +-- .../aiproxy/relay/adaptor/anthropic/main.go | 7 +- .../aiproxy/relay/adaptor/aws/llama3/main.go | 3 +- service/aiproxy/relay/adaptor/baidu/main.go | 2 +- service/aiproxy/relay/adaptor/cohere/main.go | 3 +- service/aiproxy/relay/adaptor/coze/main.go | 3 +- .../aiproxy/relay/adaptor/gemini/adaptor.go | 8 ++- .../aiproxy/relay/adaptor/gemini/constants.go | 5 ++ service/aiproxy/relay/adaptor/gemini/main.go | 65 ++++++++++++++----- service/aiproxy/relay/adaptor/ollama/main.go | 2 +- service/aiproxy/relay/adaptor/openai/model.go | 4 +- 11 files changed, 77 insertions(+), 34 deletions(-) diff --git a/service/aiproxy/common/conv/any.go b/service/aiproxy/common/conv/any.go index ed6de0d1c12..d5e3bc037fd 100644 --- a/service/aiproxy/common/conv/any.go +++ b/service/aiproxy/common/conv/any.go @@ -9,15 +9,10 @@ func AsString(v any) string { // The change of bytes will cause the change of string synchronously func BytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) + return unsafe.String(unsafe.SliceData(b), len(b)) } // If string is readonly, modifying bytes will cause panic func StringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) + return unsafe.Slice(unsafe.StringData(s), len(s)) } diff --git a/service/aiproxy/relay/adaptor/anthropic/main.go b/service/aiproxy/relay/adaptor/anthropic/main.go index b19cc12eb30..169a6e0df0d 100644 --- a/service/aiproxy/relay/adaptor/anthropic/main.go +++ b/service/aiproxy/relay/adaptor/anthropic/main.go @@ -15,6 +15,7 @@ import ( "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/image" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" + "github.com/labring/sealos/service/aiproxy/relay/constant" "github.com/labring/sealos/service/aiproxy/relay/meta" "github.com/labring/sealos/service/aiproxy/relay/model" ) @@ -26,10 +27,8 @@ func stopReasonClaude2OpenAI(reason *string) string { return "" } switch *reason { - case "end_turn": - return "stop" - case "stop_sequence": - return "stop" + case "end_turn", "stop_sequence": + return constant.StopFinishReason case "max_tokens": return "length" case toolUseType: diff --git a/service/aiproxy/relay/adaptor/aws/llama3/main.go b/service/aiproxy/relay/adaptor/aws/llama3/main.go index 8648788943f..7ec0a55001f 100644 --- a/service/aiproxy/relay/adaptor/aws/llama3/main.go +++ b/service/aiproxy/relay/adaptor/aws/llama3/main.go @@ -21,6 +21,7 @@ import ( "github.com/labring/sealos/service/aiproxy/model" "github.com/labring/sealos/service/aiproxy/relay/adaptor/aws/utils" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" + "github.com/labring/sealos/service/aiproxy/relay/constant" "github.com/labring/sealos/service/aiproxy/relay/meta" relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" "github.com/labring/sealos/service/aiproxy/relay/relaymode" @@ -214,7 +215,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus if llamaResp.PromptTokenCount > 0 { usage.PromptTokens = llamaResp.PromptTokenCount } - if llamaResp.StopReason == "stop" { + if llamaResp.StopReason == constant.StopFinishReason { usage.CompletionTokens = llamaResp.GenerationTokenCount usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } diff --git a/service/aiproxy/relay/adaptor/baidu/main.go b/service/aiproxy/relay/adaptor/baidu/main.go index f68ecbc5019..c608803dfa6 100644 --- a/service/aiproxy/relay/adaptor/baidu/main.go +++ b/service/aiproxy/relay/adaptor/baidu/main.go @@ -93,7 +93,7 @@ func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { Role: "assistant", Content: response.Result, }, - FinishReason: "stop", + FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ ID: response.ID, diff --git a/service/aiproxy/relay/adaptor/cohere/main.go b/service/aiproxy/relay/adaptor/cohere/main.go index d4b11148313..455e86b263e 100644 --- a/service/aiproxy/relay/adaptor/cohere/main.go +++ b/service/aiproxy/relay/adaptor/cohere/main.go @@ -14,6 +14,7 @@ import ( "github.com/labring/sealos/service/aiproxy/common/render" "github.com/labring/sealos/service/aiproxy/middleware" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" + "github.com/labring/sealos/service/aiproxy/relay/constant" "github.com/labring/sealos/service/aiproxy/relay/model" ) @@ -25,7 +26,7 @@ func stopReasonCohere2OpenAI(reason *string) string { } switch *reason { case "COMPLETE": - return "stop" + return constant.StopFinishReason default: return *reason } diff --git a/service/aiproxy/relay/adaptor/coze/main.go b/service/aiproxy/relay/adaptor/coze/main.go index 296769d2d8a..c0108e57b62 100644 --- a/service/aiproxy/relay/adaptor/coze/main.go +++ b/service/aiproxy/relay/adaptor/coze/main.go @@ -14,6 +14,7 @@ import ( "github.com/labring/sealos/service/aiproxy/middleware" "github.com/labring/sealos/service/aiproxy/relay/adaptor/coze/constant/messagetype" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" + "github.com/labring/sealos/service/aiproxy/relay/constant" "github.com/labring/sealos/service/aiproxy/relay/model" ) @@ -93,7 +94,7 @@ func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse { Content: responseText, Name: nil, }, - FinishReason: "stop", + FinishReason: constant.StopFinishReason, } fullTextResponse := openai.TextResponse{ ID: "chatcmpl-" + cozeResponse.ConversationID, diff --git a/service/aiproxy/relay/adaptor/gemini/adaptor.go b/service/aiproxy/relay/adaptor/gemini/adaptor.go index c2964c1809b..969be992d82 100644 --- a/service/aiproxy/relay/adaptor/gemini/adaptor.go +++ b/service/aiproxy/relay/adaptor/gemini/adaptor.go @@ -19,12 +19,18 @@ type Adaptor struct{} const baseURL = "https://generativelanguage.googleapis.com" +var v1ModelMap = map[string]struct{}{} + func getRequestURL(meta *meta.Meta, action string) string { u := meta.Channel.BaseURL if u == "" { u = baseURL } - return fmt.Sprintf("%s/%s/models/%s:%s", u, "v1beta", meta.ActualModelName, action) + version := "v1beta" + if _, ok := v1ModelMap[meta.ActualModelName]; ok { + version = "v1" + } + return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModelName, action) } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { diff --git a/service/aiproxy/relay/adaptor/gemini/constants.go b/service/aiproxy/relay/adaptor/gemini/constants.go index fc6b1354a5b..5e89798b914 100644 --- a/service/aiproxy/relay/adaptor/gemini/constants.go +++ b/service/aiproxy/relay/adaptor/gemini/constants.go @@ -28,6 +28,11 @@ var ModelList = []*model.ModelConfig{ Type: relaymode.ChatCompletions, Owner: model.ModelOwnerGoogle, }, + { + Model: "gemini-2.0-flash-thinking-exp", + Type: relaymode.ChatCompletions, + Owner: model.ModelOwnerGoogle, + }, { Model: "text-embedding-004", diff --git a/service/aiproxy/relay/adaptor/gemini/main.go b/service/aiproxy/relay/adaptor/gemini/main.go index 47fcc86018b..e4732a3430b 100644 --- a/service/aiproxy/relay/adaptor/gemini/main.go +++ b/service/aiproxy/relay/adaptor/gemini/main.go @@ -50,6 +50,7 @@ func buildSafetySettings() []ChatSafetySettings { {Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: safetySetting}, {Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: safetySetting}, {Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: safetySetting}, + {Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: safetySetting}, } } @@ -237,10 +238,13 @@ func (g *ChatResponse) GetResponseText() string { if g == nil { return "" } - if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { - return g.Candidates[0].Content.Parts[0].Text + builder := strings.Builder{} + for _, candidate := range g.Candidates { + for _, part := range candidate.Content.Parts { + builder.WriteString(part.Text) + } } - return "" + return builder.String() } type ChatCandidate struct { @@ -283,9 +287,10 @@ func getToolCalls(candidate *ChatCandidate) []*model.Tool { return toolCalls } -func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { +func responseGeminiChat2OpenAI(meta *meta.Meta, response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ ID: "chatcmpl-" + random.GetUUID(), + Model: meta.OriginModelName, Object: "chat.completion", Created: time.Now().Unix(), Choices: make([]*openai.TextResponseChoice, 0, len(response.Candidates)), @@ -302,7 +307,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { if candidate.Content.Parts[0].FunctionCall != nil { choice.Message.ToolCalls = getToolCalls(candidate) } else { - choice.Message.Content = candidate.Content.Parts[0].Text + builder := strings.Builder{} + for i, part := range candidate.Content.Parts { + if i > 0 { + builder.WriteString("\n") + } + builder.WriteString(part.Text) + } + choice.Message.Content = builder.String() } } else { choice.Message.Content = "" @@ -314,16 +326,37 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { } func streamResponseGeminiChat2OpenAI(meta *meta.Meta, geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = geminiResponse.GetResponseText() - // choice.FinishReason = &constant.StopFinishReason - var response openai.ChatCompletionsStreamResponse - response.ID = "chatcmpl-" + random.GetUUID() - response.Created = time.Now().Unix() - response.Object = "chat.completion.chunk" - response.Model = meta.OriginModelName - response.Choices = []*openai.ChatCompletionsStreamResponseChoice{&choice} - return &response + response := &openai.ChatCompletionsStreamResponse{ + ID: "chatcmpl-" + random.GetUUID(), + Created: time.Now().Unix(), + Model: meta.OriginModelName, + Object: "chat.completion.chunk", + Choices: make([]*openai.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)), + } + for i, candidate := range geminiResponse.Candidates { + choice := openai.ChatCompletionsStreamResponseChoice{ + Index: i, + } + if len(candidate.Content.Parts) > 0 { + if candidate.Content.Parts[0].FunctionCall != nil { + choice.Delta.ToolCalls = getToolCalls(candidate) + } else { + builder := strings.Builder{} + for i, part := range candidate.Content.Parts { + if i > 0 { + builder.WriteString("\n") + } + builder.WriteString(part.Text) + } + choice.Delta.Content = builder.String() + } + } else { + choice.Delta.Content = "" + choice.FinishReason = &candidate.FinishReason + } + response.Choices = append(response.Choices, &choice) + } + return response } func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) { @@ -405,7 +438,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage if len(geminiResponse.Candidates) == 0 { return nil, openai.ErrorWrapperWithMessage("No candidates returned", "gemini_error", resp.StatusCode) } - fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse := responseGeminiChat2OpenAI(meta, &geminiResponse) fullTextResponse.Model = meta.OriginModelName respContent := []ChatContent{} for _, candidate := range geminiResponse.Candidates { diff --git a/service/aiproxy/relay/adaptor/ollama/main.go b/service/aiproxy/relay/adaptor/ollama/main.go index 2452afa038a..008a7fb32c1 100644 --- a/service/aiproxy/relay/adaptor/ollama/main.go +++ b/service/aiproxy/relay/adaptor/ollama/main.go @@ -84,7 +84,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { }, } if response.Done { - choice.FinishReason = "stop" + choice.FinishReason = constant.StopFinishReason } fullTextResponse := openai.TextResponse{ ID: "chatcmpl-" + random.GetUUID(), diff --git a/service/aiproxy/relay/adaptor/openai/model.go b/service/aiproxy/relay/adaptor/openai/model.go index 6c101b93398..b83898c2daa 100644 --- a/service/aiproxy/relay/adaptor/openai/model.go +++ b/service/aiproxy/relay/adaptor/openai/model.go @@ -1,6 +1,8 @@ package openai -import "github.com/labring/sealos/service/aiproxy/relay/model" +import ( + "github.com/labring/sealos/service/aiproxy/relay/model" +) type TextContent struct { Type string `json:"type,omitempty"`