Skip to content

Commit

Permalink
feat: support google thinking
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Dec 22, 2024
1 parent 6fd6677 commit c872d96
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 34 deletions.
9 changes: 2 additions & 7 deletions service/aiproxy/common/conv/any.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,10 @@ func AsString(v any) string {

// BytesToString converts b to a string without copying; later changes to b are reflected in the returned string.
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
return unsafe.String(unsafe.SliceData(b), len(b))
}

// StringToBytes converts s to a byte slice without copying; if s is read-only, modifying the returned bytes will cause a panic.
func StringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
return unsafe.Slice(unsafe.StringData(s), len(s))
}
7 changes: 3 additions & 4 deletions service/aiproxy/relay/adaptor/anthropic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/image"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/constant"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/model"
)
Expand All @@ -26,10 +27,8 @@ func stopReasonClaude2OpenAI(reason *string) string {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "end_turn", "stop_sequence":
return constant.StopFinishReason
case "max_tokens":
return "length"
case toolUseType:
Expand Down
3 changes: 2 additions & 1 deletion service/aiproxy/relay/adaptor/aws/llama3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/aws/utils"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/constant"
"github.com/labring/sealos/service/aiproxy/relay/meta"
relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
Expand Down Expand Up @@ -214,7 +215,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
if llamaResp.PromptTokenCount > 0 {
usage.PromptTokens = llamaResp.PromptTokenCount
}
if llamaResp.StopReason == "stop" {
if llamaResp.StopReason == constant.StopFinishReason {
usage.CompletionTokens = llamaResp.GenerationTokenCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/baidu/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
FinishReason: constant.StopFinishReason,
}
fullTextResponse := openai.TextResponse{
ID: response.ID,
Expand Down
3 changes: 2 additions & 1 deletion service/aiproxy/relay/adaptor/cohere/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/labring/sealos/service/aiproxy/common/render"
"github.com/labring/sealos/service/aiproxy/middleware"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/constant"
"github.com/labring/sealos/service/aiproxy/relay/model"
)

Expand All @@ -25,7 +26,7 @@ func stopReasonCohere2OpenAI(reason *string) string {
}
switch *reason {
case "COMPLETE":
return "stop"
return constant.StopFinishReason
default:
return *reason
}
Expand Down
3 changes: 2 additions & 1 deletion service/aiproxy/relay/adaptor/coze/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/labring/sealos/service/aiproxy/middleware"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/coze/constant/messagetype"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/constant"
"github.com/labring/sealos/service/aiproxy/relay/model"
)

Expand Down Expand Up @@ -93,7 +94,7 @@ func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse {
Content: responseText,
Name: nil,
},
FinishReason: "stop",
FinishReason: constant.StopFinishReason,
}
fullTextResponse := openai.TextResponse{
ID: "chatcmpl-" + cozeResponse.ConversationID,
Expand Down
8 changes: 7 additions & 1 deletion service/aiproxy/relay/adaptor/gemini/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ type Adaptor struct{}

const baseURL = "https://generativelanguage.googleapis.com"

var v1ModelMap = map[string]struct{}{}

func getRequestURL(meta *meta.Meta, action string) string {
u := meta.Channel.BaseURL
if u == "" {
u = baseURL
}
return fmt.Sprintf("%s/%s/models/%s:%s", u, "v1beta", meta.ActualModelName, action)
version := "v1beta"
if _, ok := v1ModelMap[meta.ActualModelName]; ok {
version = "v1"
}
return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModelName, action)
}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
Expand Down
5 changes: 5 additions & 0 deletions service/aiproxy/relay/adaptor/gemini/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ var ModelList = []*model.ModelConfig{
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerGoogle,
},
{
Model: "gemini-2.0-flash-thinking-exp",
Type: relaymode.ChatCompletions,
Owner: model.ModelOwnerGoogle,
},

{
Model: "text-embedding-004",
Expand Down
65 changes: 49 additions & 16 deletions service/aiproxy/relay/adaptor/gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func buildSafetySettings() []ChatSafetySettings {
{Category: "HARM_CATEGORY_HATE_SPEECH", Threshold: safetySetting},
{Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Threshold: safetySetting},
{Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: safetySetting},
{Category: "HARM_CATEGORY_CIVIC_INTEGRITY", Threshold: safetySetting},
}
}

Expand Down Expand Up @@ -237,10 +238,13 @@ func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
builder := strings.Builder{}
for _, candidate := range g.Candidates {
for _, part := range candidate.Content.Parts {
builder.WriteString(part.Text)
}
}
return ""
return builder.String()
}

type ChatCandidate struct {
Expand Down Expand Up @@ -283,9 +287,10 @@ func getToolCalls(candidate *ChatCandidate) []*model.Tool {
return toolCalls
}

func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
func responseGeminiChat2OpenAI(meta *meta.Meta, response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
ID: "chatcmpl-" + random.GetUUID(),
Model: meta.OriginModelName,
Object: "chat.completion",
Created: time.Now().Unix(),
Choices: make([]*openai.TextResponseChoice, 0, len(response.Candidates)),
Expand All @@ -302,7 +307,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(candidate)
} else {
choice.Message.Content = candidate.Content.Parts[0].Text
builder := strings.Builder{}
for i, part := range candidate.Content.Parts {
if i > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
choice.Message.Content = builder.String()
}
} else {
choice.Message.Content = ""
Expand All @@ -314,16 +326,37 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
}

func streamResponseGeminiChat2OpenAI(meta *meta.Meta, geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
// choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.ID = "chatcmpl-" + random.GetUUID()
response.Created = time.Now().Unix()
response.Object = "chat.completion.chunk"
response.Model = meta.OriginModelName
response.Choices = []*openai.ChatCompletionsStreamResponseChoice{&choice}
return &response
response := &openai.ChatCompletionsStreamResponse{
ID: "chatcmpl-" + random.GetUUID(),
Created: time.Now().Unix(),
Model: meta.OriginModelName,
Object: "chat.completion.chunk",
Choices: make([]*openai.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)),
}
for i, candidate := range geminiResponse.Candidates {
choice := openai.ChatCompletionsStreamResponseChoice{
Index: i,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Delta.ToolCalls = getToolCalls(candidate)
} else {
builder := strings.Builder{}
for i, part := range candidate.Content.Parts {
if i > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
choice.Delta.Content = builder.String()
}
} else {
choice.Delta.Content = ""
choice.FinishReason = &candidate.FinishReason
}
response.Choices = append(response.Choices, &choice)
}
return response
}

func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *model.ErrorWithStatusCode) {
Expand Down Expand Up @@ -405,7 +438,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
if len(geminiResponse.Candidates) == 0 {
return nil, openai.ErrorWrapperWithMessage("No candidates returned", "gemini_error", resp.StatusCode)
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse := responseGeminiChat2OpenAI(meta, &geminiResponse)
fullTextResponse.Model = meta.OriginModelName
respContent := []ChatContent{}
for _, candidate := range geminiResponse.Candidates {
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
},
}
if response.Done {
choice.FinishReason = "stop"
choice.FinishReason = constant.StopFinishReason
}
fullTextResponse := openai.TextResponse{
ID: "chatcmpl-" + random.GetUUID(),
Expand Down
4 changes: 3 additions & 1 deletion service/aiproxy/relay/adaptor/openai/model.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package openai

import "github.com/labring/sealos/service/aiproxy/relay/model"
import (
"github.com/labring/sealos/service/aiproxy/relay/model"
)

type TextContent struct {
Type string `json:"type,omitempty"`
Expand Down

0 comments on commit c872d96

Please sign in to comment.